From fbd5a808f6cb4ab2e73cb90c4699d8b525966e31 Mon Sep 17 00:00:00 2001
From: Hollow Man
Date: Wed, 4 Mar 2026 17:03:02 +0200
Subject: [PATCH] Store hf_pretrained as properties of Megatron*Bridge classes

So that downstream model bridges that need hf_pretrained config information
to build their mapping_registry no longer need to override
build_conversion_tasks (e.g. the GLM 4.5 bridge).
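
As an illustration, a minimal sketch of the pattern this enables (the
DownstreamBridge class here is hypothetical; the GLM 4.5 bridge changes
below follow the same shape). build_conversion_tasks() now stores
hf_pretrained and hf_config on the bridge instance before
mapping_registry() runs, so a config-dependent registry needs no override:

    class DownstreamBridge(MegatronModelBridge):
        def mapping_registry(self) -> MegatronMappingRegistry:
            mapping_list = [...]  # static mappings elided
            # self.hf_config is set by the base build_conversion_tasks()
            num_mtp_layers = getattr(self.hf_config, "num_nextn_predict_layers", 0)
            for mtp_layer in range(num_mtp_layers):
                ...  # append per-MTP-layer mappings derived from the HF config
            return MegatronMappingRegistry(*mapping_list)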

Signed-off-by: Hollow Man
---
 .../bridge/models/conversion/auto_bridge.py   |   1 +
 .../bridge/models/conversion/model_bridge.py  |  16 +-
 .../bridge/models/conversion/peft_bridge.py   |  15 +-
 .../bridge/models/glm/glm45_bridge.py         | 171 +++++++++---------
 .../bridge/models/glm_vl/glm_45v_bridge.py    |  23 ++-
 .../models/glm/test_glm45_bridge.py           |  12 +-
 .../models/glm_vl/test_glm_45v_bridge.py      |  10 +-
 tests/unit_tests/models/test_auto_bridge.py   |  48 ++++-
 .../models/test_model_bridge_lora.py          |  13 +-
 9 files changed, 196 insertions(+), 113 deletions(-)

diff --git a/src/megatron/bridge/models/conversion/auto_bridge.py b/src/megatron/bridge/models/conversion/auto_bridge.py
index 3e0ac4e02a..231ef0e7e4 100644
--- a/src/megatron/bridge/models/conversion/auto_bridge.py
+++ b/src/megatron/bridge/models/conversion/auto_bridge.py
@@ -406,6 +406,7 @@ def export_adapter_weights(
         return model_bridge.stream_adapter_weights_megatron_to_hf(
             dispatch_instance,
             model,
+            self.hf_pretrained,
             cpu=cpu,
             show_progress=show_progress,
         )
diff --git a/src/megatron/bridge/models/conversion/model_bridge.py b/src/megatron/bridge/models/conversion/model_bridge.py
index c5a6155b12..73b0c89f14 100644
--- a/src/megatron/bridge/models/conversion/model_bridge.py
+++ b/src/megatron/bridge/models/conversion/model_bridge.py
@@ -960,7 +960,7 @@ def stream_weights_megatron_to_hf(
         # Collect adapter conversion tasks when merge is requested
         adapter_tasks_by_base: Dict[str, List[AdapterWeightConversionTask]] = {}
         if merge_adapter_weights:
-            adapter_tasks_by_base = self.build_adapter_conversion_tasks(megatron_model)
+            adapter_tasks_by_base = self.build_adapter_conversion_tasks(hf_pretrained, megatron_model)
 
         megatron_to_hf_tasks = conversion_tasks
         unwrapped_model = unwrap_model(megatron_model)[0]
@@ -1205,6 +1205,9 @@ def build_conversion_tasks(
         if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
             raise ValueError("hf_pretrained.state.source is required for weight ordering")
 
+        self.hf_pretrained = hf_pretrained
+        self.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained
+
         hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys()
 
         mapping_registry = self.mapping_registry()
@@ -1394,6 +1397,7 @@ def stream_weights_megatron_to_hf(
 def stream_adapter_weights_megatron_to_hf(
     dispatch_instance: MegatronModel,
     megatron_model: Union[MegatronModel, List[MegatronModel]],
+    hf_pretrained: HFPreTrained,
     cpu: bool = True,
     show_progress: bool = True,
 ) -> Iterable[HFWeightTuple]:
@@ -1437,7 +1441,8 @@ def _megatron_to_hf_registered_impl(
     ) -> Iterable[HFWeightTuple]:
         bridge = bridge_class()
 
-        # allow bridge to access model config (config-only shims or raw configs lack .config)
+        # allow bridge to access model config
+        bridge.hf_pretrained = hf_pretrained
         bridge.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained
 
         return bridge.stream_weights_megatron_to_hf(
@@ -1453,12 +1458,19 @@ def _megatron_to_hf_registered_impl(
     def _adapter_stream_registered_impl(
         _,
         megatron_model: Union[MegatronModel, List[MegatronModel]],
+        hf_pretrained: HFPreTrained,
         cpu: bool = True,
         show_progress: bool = True,
     ) -> Iterable[HFWeightTuple]:
         bridge = bridge_class()
+
+        # allow bridge to access model config
+        bridge.hf_pretrained = hf_pretrained
+        bridge.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained
+
         return bridge.stream_adapter_weights_megatron_to_hf(
             megatron_model,
+            hf_pretrained,
             cpu=cpu,
             show_progress=show_progress,
         )
diff --git a/src/megatron/bridge/models/conversion/peft_bridge.py b/src/megatron/bridge/models/conversion/peft_bridge.py
index 85646734dd..ad28757cb5 100644
--- a/src/megatron/bridge/models/conversion/peft_bridge.py
+++ b/src/megatron/bridge/models/conversion/peft_bridge.py
@@ -44,7 +44,12 @@
 
 if TYPE_CHECKING:
     from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
-    from megatron.bridge.models.conversion.model_bridge import HFWeightTuple, MegatronWeightTuple, WeightConversionTask
+    from megatron.bridge.models.conversion.model_bridge import (
+        HFPreTrained,
+        HFWeightTuple,
+        MegatronWeightTuple,
+        WeightConversionTask,
+    )
 
 
 MegatronModel = TypeVar("MegatronModel", bound=MegatronModule)
@@ -439,7 +444,7 @@ def _construct_adapters_names(self, prefix: str, adapter_key: Optional[str]) ->
         return linear_in_name, linear_out_name
 
     def build_adapter_conversion_tasks(
-        self, megatron_model: Union[MegatronModel, List[MegatronModel]]
+        self, hf_pretrained: HFPreTrained, megatron_model: Union[MegatronModel, List[MegatronModel]]
     ) -> Dict[str, List[AdapterWeightConversionTask]]:
         """Construct adapter merge tasks keyed by their base parameter.
 
@@ -449,6 +454,9 @@ def build_adapter_conversion_tasks(
             merged into that base weight.
         """
 
+        self.hf_pretrained = hf_pretrained
+        self.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained
+
         if not isinstance(megatron_model, list):
             megatron_model = [megatron_model]
 
@@ -598,6 +606,7 @@ def materialize_adapter_weights(self, adapter_tasks: List[AdapterWeightConversio
     def stream_adapter_weights_megatron_to_hf(
         self,
         megatron_model: Union[MegatronModel, List[MegatronModel]],
+        hf_pretrained: HFPreTrained,
         cpu: bool = True,
         show_progress: bool = True,
     ) -> Iterable[HFWeightTuple]:
@@ -610,7 +619,7 @@ def stream_adapter_weights_megatron_to_hf(
             megatron_model = [megatron_model]
 
         num_moe_experts = megatron_model[0].config.num_moe_experts
-        adapter_tasks_by_base = self.build_adapter_conversion_tasks(megatron_model)
+        adapter_tasks_by_base = self.build_adapter_conversion_tasks(hf_pretrained, megatron_model)
         adapter_tasks = list(itertools.chain.from_iterable(adapter_tasks_by_base.values()))
         if not adapter_tasks:
             return
diff --git a/src/megatron/bridge/models/glm/glm45_bridge.py b/src/megatron/bridge/models/glm/glm45_bridge.py
index cbb1ca9238..1f609fea8f 100644
--- a/src/megatron/bridge/models/glm/glm45_bridge.py
+++ b/src/megatron/bridge/models/glm/glm45_bridge.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 from functools import partial
 
 import torch
@@ -45,9 +44,6 @@
     HAVE_TE = False
 
 
-logger = logging.getLogger(__name__)
-
-
 @MegatronModelBridge.register_bridge(source=Glm4MoeForCausalLM, target=GPTModel, model_type="glm4_moe")
 class GLM45Bridge(MegatronModelBridge):
     """
@@ -101,14 +97,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider
 
         return provider
 
-    def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        # Store config on instance for use in mapping_registry
-        self._hf_config = hf_pretrained.config
-        self._hf_state_source = hf_pretrained.state.source
-        self._hf_keys = list(self._hf_state_source.get_all_keys())
-        return super().build_conversion_tasks(hf_pretrained, megatron_model)
-
     def mapping_registry(self) -> MegatronMappingRegistry:
         mapping_list = []
         use_fused_experts = self._uses_fused_experts()
@@ -206,23 +194,11 @@ def mapping_registry(self) -> MegatronMappingRegistry:
                 ),
             ]
         )
-        # optionally add MTP mappings
-        if not hasattr(self, "_hf_config"):
-            logger.warning("No HF config found, skipping MTP mappings.")
-            return MegatronMappingRegistry(*mapping_list)
-        hf_config = self._hf_config
+        # add MTP mappings
+        hf_config = self.hf_config
         num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
         num_transformer_layers = hf_config.num_hidden_layers
         for mtp_layer in range(num_mtp_layers):
-            for megatron_param, hf_param in layer_specific_mappings.items():
-                megatron_param = (
-                    megatron_param.replace(".*", ".*.transformer_layer")
-                    .replace("decoder", "mtp")
-                    .replace(".*", f".{mtp_layer}")
-                )
-                hf_param = hf_param.replace("layers.*", f"layers.{mtp_layer + num_transformer_layers}")
-                mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))
-
             # MTP specific mappings
             mapping_list.extend(
                 [
@@ -244,96 +220,113 @@ def mapping_registry(self) -> MegatronMappingRegistry:
                     ),
                 ]
             )
-            # Special mappings that require parameter concatenation/transformation
-            mapping_list.extend(
-                [
-                    QKVMapping(
-                        megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.self_attention.linear_qkv.weight",
-                        q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.weight",
-                        k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.weight",
-                        v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.weight",
-                    ),
-                    QKVMapping(
-                        megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.self_attention.linear_qkv.bias",
-                        q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.bias",
-                        k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.bias",
-                        v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.bias",
-                    ),
-                    GatedMLPMapping(
-                        megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.mlp.linear_fc1.weight",
-                        gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.gate.weight",
-                        up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.up.weight",
-                    ),
-                    GatedMLPMapping(
-                        megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.mlp.shared_experts.linear_fc1.weight",
-                        gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.gate_proj.weight",
-                        up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.up_proj.weight",
-                    ),
-                ]
-            )
-            if use_fused_experts:
+
+            for layer_prefix in ("transformer_layer", "mtp_model_layer"):
+                for megatron_param, hf_param in layer_specific_mappings.items():
+                    megatron_param = (
+                        megatron_param.replace(".*", f".*.{layer_prefix}")
+                        .replace("decoder", "mtp")
+                        .replace(".*", f".{mtp_layer}")
+                    )
+                    hf_param = hf_param.replace("layers.*", f"layers.{mtp_layer + num_transformer_layers}")
+                    mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))
+
+                # Special mappings that require parameter concatenation/transformation
                 mapping_list.extend(
                     [
-                        GLMExpertGateUpProjMapping(
-                            megatron_param=(
-                                f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc1.weight*"
-                            ),
-                            hf_param=(
-                                f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.gate_up_proj"
-                                f"{gate_up_suffix}"
-                            ),
+                        QKVMapping(
+                            megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.self_attention.linear_qkv.weight",
+                            q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.weight",
+                            k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.weight",
+                            v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.weight",
                         ),
-                        GLMExpertDownProjMapping(
-                            megatron_param=(
-                                f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc2.weight*"
-                            ),
-                            hf_param=(
-                                f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.down_proj{down_suffix}"
-                            ),
+                        QKVMapping(
+                            megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.self_attention.linear_qkv.bias",
+                            q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.bias",
+                            k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.bias",
+                            v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.bias",
                         ),
-                    ]
-                )
-            else:
-                mapping_list.extend(
-                    [
                         GatedMLPMapping(
-                            megatron_param=(
-                                f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc1.weight*"
-                            ),
-                            gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.gate_proj.weight",
-                            up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.up_proj.weight",
+                            megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.linear_fc1.weight",
+                            gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.gate.weight",
+                            up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.up.weight",
                         ),
-                        AutoMapping(
-                            megatron_param=(
-                                f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc2.weight*"
-                            ),
-                            hf_param=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.down_proj.weight",
+                        GatedMLPMapping(
+                            megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.shared_experts.linear_fc1.weight",
+                            gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.gate_proj.weight",
+                            up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.up_proj.weight",
                         ),
                     ]
                 )
+                if use_fused_experts:
+                    mapping_list.extend(
+                        [
+                            GLMExpertGateUpProjMapping(
+                                megatron_param=(
+                                    f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc1.weight*"
+                                ),
+                                hf_param=(
+                                    f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.gate_up_proj"
+                                    f"{gate_up_suffix}"
+                                ),
+                            ),
+                            GLMExpertDownProjMapping(
+                                megatron_param=(
+                                    f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc2.weight*"
+                                ),
+                                hf_param=(
+                                    f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.down_proj{down_suffix}"
+                                ),
+                            ),
+                        ]
+                    )
+                else:
+                    mapping_list.extend(
+                        [
+                            GatedMLPMapping(
+                                megatron_param=(
+                                    f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc1.weight*"
+                                ),
+                                gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.gate_proj.weight",
+                                up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.up_proj.weight",
+                            ),
+                            AutoMapping(
+                                megatron_param=(
+                                    f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc2.weight*"
+                                ),
+                                hf_param=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.down_proj.weight",
+                            ),
+                        ]
+                    )
 
         return MegatronMappingRegistry(*mapping_list)
 
+    def _hf_source_and_keys(self):
+        """Return HF state source and cached key order for expert-mapping helpers."""
+        hf_source = self.hf_pretrained.state.source
+        if getattr(self, "_cached_hf_state_source", None) is not hf_source:
+            self._cached_hf_state_source = hf_source
+            self._cached_hf_keys = hf_source.get_all_keys()
+        return hf_source, self._cached_hf_keys
+
     def _uses_fused_experts(self) -> bool:
-        hf_keys = getattr(self, "_hf_keys", None)
+        hf_source, hf_keys = self._hf_source_and_keys()
         if hf_keys:
             if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any(
                 "mlp.experts.down_proj" in key for key in hf_keys
             ):
                 return True
 
-        hf_source = getattr(self, "_hf_state_source", None)
         if hf_source is not None:
             return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*")
 
         return False
 
     def _hf_expert_suffix(self, base_name: str) -> str:
-        hf_keys = getattr(self, "_hf_keys", None) or []
+        hf_source, hf_keys = self._hf_source_and_keys()
         if any(f"{base_name}.weight" in key for key in hf_keys):
             return ".weight"
 
-        hf_source = getattr(self, "_hf_state_source", None)
         if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"):
             return ".weight"
@@ -351,7 +344,7 @@ def maybe_modify_converted_hf_weight(
 
         if not converted_weights_dict:
             return {}
 
-        num_experts = self._hf_config.n_routed_experts
+        num_experts = self.hf_config.n_routed_experts
         ep_size = parallel_state.get_expert_model_parallel_world_size()
         experts_per_rank = num_experts // ep_size
diff --git a/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py b/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
index 2357444d98..202b3bf96f 100644
--- a/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
+++ b/src/megatron/bridge/models/glm_vl/glm_45v_bridge.py
@@ -92,13 +92,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedVLM) -> GLM45VModelProvider:
         )
         return provider
 
-    def build_conversion_tasks(self, hf_pretrained, megatron_model):
-        """Override to store config before mapping_registry is called."""
-        self._hf_config = hf_pretrained.config
-        self._hf_state_source = hf_pretrained.state.source
-        self._hf_keys = list(self._hf_state_source.get_all_keys())
-        return super().build_conversion_tasks(hf_pretrained, megatron_model)
-
     @classmethod
     def get_hf_tokenizer_kwargs(cls) -> dict:
         """Return HuggingFace tokenizer kwargs specific to GLM 4.5V models.
@@ -215,26 +208,32 @@ def mapping_registry(self) -> MegatronMappingRegistry:
         )
         return MegatronMappingRegistry(*mapping_list)
 
+    def _hf_source_and_keys(self):
+        """Return HF state source and cached key order for expert-mapping helpers."""
+        hf_source = self.hf_pretrained.state.source
+        if getattr(self, "_cached_hf_state_source", None) is not hf_source:
+            self._cached_hf_state_source = hf_source
+            self._cached_hf_keys = hf_source.get_all_keys()
+        return hf_source, self._cached_hf_keys
+
     def _uses_fused_experts(self) -> bool:
-        hf_keys = getattr(self, "_hf_keys", None)
+        hf_source, hf_keys = self._hf_source_and_keys()
         if hf_keys:
             if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any(
                 "mlp.experts.down_proj" in key for key in hf_keys
             ):
                 return True
 
-        hf_source = getattr(self, "_hf_state_source", None)
         if hf_source is not None:
             return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*")
 
         return False
 
     def _hf_expert_suffix(self, base_name: str) -> str:
-        hf_keys = getattr(self, "_hf_keys", None) or []
+        hf_source, hf_keys = self._hf_source_and_keys()
         if any(f"{base_name}.weight" in key for key in hf_keys):
             return ".weight"
 
-        hf_source = getattr(self, "_hf_state_source", None)
         if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"):
             return ".weight"
@@ -252,7 +251,7 @@ def maybe_modify_converted_hf_weight(
 
         if not converted_weights_dict:
             return {}
 
-        text_config = getattr(self._hf_config, "text_config", self._hf_config)
+        text_config = getattr(self.hf_pretrained.config, "text_config", self.hf_pretrained.config)
         num_experts = text_config.n_routed_experts
         ep_size = parallel_state.get_expert_model_parallel_world_size()
         experts_per_rank = num_experts // ep_size
diff --git a/tests/unit_tests/models/glm/test_glm45_bridge.py b/tests/unit_tests/models/glm/test_glm45_bridge.py
index 6cd9d92cac..e047797cc0 100644
--- a/tests/unit_tests/models/glm/test_glm45_bridge.py
+++ b/tests/unit_tests/models/glm/test_glm45_bridge.py
@@ -122,6 +122,10 @@ def mock_pretrained_355b(self, glm45_355b_config):
         m = Mock(spec=PreTrainedCausalLM)
         m.config = cfg
         m.generation_config = Mock(spec=GenerationConfig)
+        m.state = Mock()
+        m.state.source = Mock()
+        m.state.source.get_all_keys.return_value = []
+        m.state.source.has_glob.return_value = False
         return m
 
     @pytest.fixture
@@ -135,6 +139,10 @@ def mock_pretrained_air_106b(self, glm45_air_106b_config):
         m = Mock(spec=PreTrainedCausalLM)
         m.config = cfg
         m.generation_config = Mock(spec=GenerationConfig)
+        m.state = Mock()
+        m.state.source = Mock()
+        m.state.source.get_all_keys.return_value = []
+        m.state.source.has_glob.return_value = False
         return m
 
     def test_registration(self):
@@ -198,9 +206,11 @@ def test_provider_bridge_maps_config_air_106b(self, mock_pretrained_air_106b):
         assert provider.bf16 is True
         assert provider.params_dtype == torch.bfloat16
 
-    def test_mapping_registry_exists(self):
+    def test_mapping_registry_exists(self, mock_pretrained_355b):
         """Test that mapping registry is properly defined."""
         bridge = GLM45Bridge()
+        bridge.hf_pretrained = mock_pretrained_355b
+        bridge.hf_config = mock_pretrained_355b.config
         registry = bridge.mapping_registry()
 
         # Verify registry has mappings
diff --git a/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py b/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py
index 348d94b79b..1cfa9f302a 100644
--- a/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py
+++ b/tests/unit_tests/models/glm_vl/test_glm_45v_bridge.py
@@ -83,13 +83,19 @@ def mock_hf_pretrained(mock_hf_config):
     """Create a mock HF pretrained VLM."""
     pretrained = Mock(spec=PreTrainedVLM)
     pretrained.config = mock_hf_config
+    pretrained.state = Mock()
+    pretrained.state.source = Mock()
+    pretrained.state.source.get_all_keys.return_value = []
+    pretrained.state.source.has_glob.return_value = False
     return pretrained
 
 
 @pytest.fixture
-def glm_45v_bridge():
+def glm_45v_bridge(mock_hf_pretrained):
     """Create a GLM45VBridge instance."""
-    return GLM45VBridge()
+    bridge = GLM45VBridge()
+    bridge.hf_pretrained = mock_hf_pretrained
+    return bridge
 
 
 class TestGLM45VBridgeInitialization:
diff --git a/tests/unit_tests/models/test_auto_bridge.py b/tests/unit_tests/models/test_auto_bridge.py
index d11482c074..ee9a196394 100644
--- a/tests/unit_tests/models/test_auto_bridge.py
+++ b/tests/unit_tests/models/test_auto_bridge.py
@@ -611,8 +611,7 @@ def test_export_hf_weights(self):
         mock_hf_model.config.architectures = ["LlamaForCausalLM"]
         mock_hf_model.config.auto_map = None
 
-        mock_megatron_model = [Mock()]
-        mock_megatron_model[0].module = None  # No nested module
+        mock_megatron_model = [object()]
 
         # Mock the export process
         with patch(
@@ -637,6 +636,51 @@ def test_export_hf_weights(self):
             assert weights[1][0] == "weight2"
             assert isinstance(weights[0][1], torch.Tensor)
             assert isinstance(weights[1][1], torch.Tensor)
+            mock_bridge_state.assert_called_once_with(
+                (mock_arch_class, mock_megatron_model[0]),
+                mock_megatron_model,
+                mock_hf_model,
+                cpu=True,
+                show_progress=True,
+                conversion_tasks=None,
+                merge_adapter_weights=True,
+            )
+
+    def test_export_adapter_weights(self):
+        """Test exporting adapter weights from Megatron to HF format."""
+        mock_hf_model = Mock(spec=PreTrainedCausalLM)
+        mock_hf_model.config = Mock()
+        mock_hf_model.config.architectures = ["LlamaForCausalLM"]
+        mock_hf_model.config.auto_map = None
+
+        mock_megatron_model = [object()]
+
+        with patch(
+            "megatron.bridge.models.conversion.auto_bridge.model_bridge.stream_adapter_weights_megatron_to_hf"
+        ) as mock_stream_adapter_weights:
+            mock_weight_iter = [("adapter.weight", torch.randn(4, 4))]
+            mock_stream_adapter_weights.return_value = iter(mock_weight_iter)
+
+            with patch("megatron.bridge.models.conversion.auto_bridge.transformers") as mock_transformers:
+                mock_arch_class = Mock()
+                mock_transformers.LlamaForCausalLM = mock_arch_class
+
+                bridge = AutoBridge(mock_hf_model)
+
+                with patch.object(AutoBridge, "_causal_lm_architecture", new_callable=PropertyMock) as mock_prop:
+                    mock_prop.return_value = mock_arch_class
+                    weights = list(bridge.export_adapter_weights(mock_megatron_model, cpu=False, show_progress=False))
+
+                    assert len(weights) == 1
+                    assert weights[0][0] == "adapter.weight"
+                    assert isinstance(weights[0][1], torch.Tensor)
+                    mock_stream_adapter_weights.assert_called_once_with(
+                        (mock_arch_class, mock_megatron_model[0]),
+                        mock_megatron_model,
+                        mock_hf_model,
+                        cpu=False,
+                        show_progress=False,
+                    )
 
     def test_get_causal_lm_architecture(self):
         """Test getting the CausalLM architecture class."""
diff --git a/tests/unit_tests/models/test_model_bridge_lora.py b/tests/unit_tests/models/test_model_bridge_lora.py
index 6641f72c4c..53f5298825 100644
--- a/tests/unit_tests/models/test_model_bridge_lora.py
+++ b/tests/unit_tests/models/test_model_bridge_lora.py
@@ -414,7 +414,7 @@ def test_build_adapter_conversion_tasks(monkeypatch):
     )
     monkeypatch.setattr(bridge, "mapping_registry", lambda: registry)
 
-    tasks_by_base = bridge.build_adapter_conversion_tasks([Mock()])
+    tasks_by_base = bridge.build_adapter_conversion_tasks(SimpleNamespace(), [Mock()])
     assert "decoder.layers.0.mlp.linear_fc1" in tasks_by_base
     tasks = tasks_by_base["decoder.layers.0.mlp.linear_fc1"]
     assert len(tasks) == 1
@@ -512,7 +512,14 @@ def test_stream_adapter_weights_megatron_to_hf(monkeypatch):
     )
     megatron_model = [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))]
 
-    weights = list(bridge.stream_adapter_weights_megatron_to_hf(megatron_model, cpu=False, show_progress=False))
+    weights = list(
+        bridge.stream_adapter_weights_megatron_to_hf(
+            megatron_model,
+            SimpleNamespace(),
+            cpu=False,
+            show_progress=False,
+        )
+    )
     assert len(weights) == 2
     assert weights[0].param_name.endswith("lora_A.weight")
     assert weights[1].param_name.endswith("lora_B.weight")
@@ -575,6 +582,7 @@ def test_stream_adapter_weights_megatron_to_hf_qkv(monkeypatch):
     weights = list(
         bridge.stream_adapter_weights_megatron_to_hf(
             [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))],
+            SimpleNamespace(),
             cpu=False,
             show_progress=False,
         )
     )
@@ -648,6 +656,7 @@ def test_stream_adapter_weights_megatron_to_hf_fused_fc1(monkeypatch):
     weights = list(
         bridge.stream_adapter_weights_megatron_to_hf(
             [SimpleNamespace(config=SimpleNamespace(num_moe_experts=0))],
+            SimpleNamespace(),
             cpu=False,
             show_progress=False,
         )