1 change: 1 addition & 0 deletions src/megatron/bridge/models/conversion/auto_bridge.py
@@ -406,6 +406,7 @@ def export_adapter_weights(
return model_bridge.stream_adapter_weights_megatron_to_hf(
dispatch_instance,
model,
self.hf_pretrained,
cpu=cpu,
show_progress=show_progress,
)
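A minimal usage sketch of the updated export entry point. The method name, `cpu`, and `show_progress` come from the diff; the `dump_adapter_weights` wrapper, the assumption that the method lives on `AutoBridge`, and the assumption that each yielded `HFWeightTuple` unpacks to a (name, tensor) pair are illustrative only:

```python
# Hedged sketch, not part of the PR: stream converted adapter weights in HF naming.
# export_adapter_weights now forwards self.hf_pretrained (the line added above) so the
# PEFT bridge can consult the HF config while building its conversion tasks.
def dump_adapter_weights(bridge, model) -> None:
    for hf_name, tensor in bridge.export_adapter_weights(model, cpu=True, show_progress=False):
        print(hf_name, tuple(tensor.shape))
```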
16 changes: 14 additions & 2 deletions src/megatron/bridge/models/conversion/model_bridge.py
@@ -960,7 +960,7 @@ def stream_weights_megatron_to_hf(
# Collect adapter conversion tasks when merge is requested
adapter_tasks_by_base: Dict[str, List[AdapterWeightConversionTask]] = {}
if merge_adapter_weights:
adapter_tasks_by_base = self.build_adapter_conversion_tasks(megatron_model)
adapter_tasks_by_base = self.build_adapter_conversion_tasks(hf_pretrained, megatron_model)

megatron_to_hf_tasks = conversion_tasks
unwrapped_model = unwrap_model(megatron_model)[0]
@@ -1205,6 +1205,9 @@ def build_conversion_tasks(
if not (hasattr(hf_pretrained, "state") and hasattr(hf_pretrained.state, "source")):
raise ValueError("hf_pretrained.state.source is required for weight ordering")

self.hf_pretrained = hf_pretrained
self.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained

hf_keys: Iterable[str] = hf_pretrained.state.source.get_all_keys()

mapping_registry = self.mapping_registry()
@@ -1394,6 +1397,7 @@ def stream_weights_megatron_to_hf(
def stream_adapter_weights_megatron_to_hf(
dispatch_instance: MegatronModel,
megatron_model: Union[MegatronModel, List[MegatronModel]],
hf_pretrained: HFPreTrained,
cpu: bool = True,
show_progress: bool = True,
) -> Iterable[HFWeightTuple]:
@@ -1437,7 +1441,8 @@ def _megatron_to_hf_registered_impl(
) -> Iterable[HFWeightTuple]:
bridge = bridge_class()

# allow bridge to access model config (config-only shims or raw configs lack .config)
# allow bridge to access model config
bridge.hf_pretrained = hf_pretrained
bridge.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained

return bridge.stream_weights_megatron_to_hf(
@@ -1453,12 +1458,19 @@ def _megatron_to_hf_registered_impl(
def _adapter_stream_registered_impl(
_,
megatron_model: Union[MegatronModel, List[MegatronModel]],
hf_pretrained: HFPreTrained,
cpu: bool = True,
show_progress: bool = True,
) -> Iterable[HFWeightTuple]:
bridge = bridge_class()

# allow bridge to access model config
bridge.hf_pretrained = hf_pretrained
bridge.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained

return bridge.stream_adapter_weights_megatron_to_hf(
megatron_model,
hf_pretrained,
cpu=cpu,
show_progress=show_progress,
)
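Putting the model_bridge.py changes together: the adapter streaming dispatch now carries the HF handle alongside the Megatron model, and both registered implementations attach it to the bridge before streaming. A condensed call sketch, with argument order taken from the diff and all variable names illustrative:

```python
# Hedged sketch of the updated dispatch call defined in model_bridge.py.
def stream_adapters(model_bridge, dispatch_instance, megatron_model, hf_pretrained):
    weight_stream = model_bridge.stream_adapter_weights_megatron_to_hf(
        dispatch_instance,   # instance used to select the registered bridge implementation
        megatron_model,      # Megatron module(s) holding the adapter weights
        hf_pretrained,       # new argument: becomes bridge.hf_pretrained / bridge.hf_config
        cpu=True,
        show_progress=True,
    )
    for hf_name, tensor in weight_stream:
        print(hf_name, tuple(tensor.shape))
```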
15 changes: 12 additions & 3 deletions src/megatron/bridge/models/conversion/peft_bridge.py
@@ -44,7 +44,12 @@

if TYPE_CHECKING:
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import HFWeightTuple, MegatronWeightTuple, WeightConversionTask
from megatron.bridge.models.conversion.model_bridge import (
HFPreTrained,
HFWeightTuple,
MegatronWeightTuple,
WeightConversionTask,
)


MegatronModel = TypeVar("MegatronModel", bound=MegatronModule)
@@ -439,7 +444,7 @@ def _construct_adapters_names(self, prefix: str, adapter_key: Optional[str]) ->
return linear_in_name, linear_out_name

def build_adapter_conversion_tasks(
self, megatron_model: Union[MegatronModel, List[MegatronModel]]
self, hf_pretrained: HFPreTrained, megatron_model: Union[MegatronModel, List[MegatronModel]]
) -> Dict[str, List[AdapterWeightConversionTask]]:
"""Construct adapter merge tasks keyed by their base parameter.

@@ -449,6 +454,9 @@ def build_adapter_conversion_tasks(
merged into that base weight.
"""

self.hf_pretrained = hf_pretrained
self.hf_config = hf_pretrained.config if hasattr(hf_pretrained, "config") else hf_pretrained

if not isinstance(megatron_model, list):
megatron_model = [megatron_model]

@@ -598,6 +606,7 @@ def materialize_adapter_weights(self, adapter_tasks: List[AdapterWeightConversio
def stream_adapter_weights_megatron_to_hf(
self,
megatron_model: Union[MegatronModel, List[MegatronModel]],
hf_pretrained: HFPreTrained,
cpu: bool = True,
show_progress: bool = True,
) -> Iterable[HFWeightTuple]:
@@ -610,7 +619,7 @@ def stream_adapter_weights_megatron_to_hf(
megatron_model = [megatron_model]

num_moe_experts = megatron_model[0].config.num_moe_experts
adapter_tasks_by_base = self.build_adapter_conversion_tasks(megatron_model)
adapter_tasks_by_base = self.build_adapter_conversion_tasks(hf_pretrained, megatron_model)
adapter_tasks = list(itertools.chain.from_iterable(adapter_tasks_by_base.values()))
if not adapter_tasks:
return
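Because the adapter path now sets `hf_pretrained`/`hf_config` on the bridge before tasks are built, a subclass can read the HF config lazily instead of caching it in a `build_conversion_tasks` override; that is what the glm45_bridge.py changes below rely on. A hedged sketch (the `MegatronModelBridge` import path and the config field are assumptions based on this diff):

```python
from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry
from megatron.bridge.models.conversion.model_bridge import MegatronModelBridge  # assumed import path

class MyBridge(MegatronModelBridge):  # illustrative subclass only
    def mapping_registry(self) -> MegatronMappingRegistry:
        mapping_list = []
        # self.hf_config is populated by the conversion machinery before this runs,
        # on both the full-model and adapter-only export paths.
        if getattr(self.hf_config, "num_nextn_predict_layers", 0) > 0:
            pass  # register MTP mappings here, as glm45_bridge.py does
        return MegatronMappingRegistry(*mapping_list)
```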
171 changes: 82 additions & 89 deletions src/megatron/bridge/models/glm/glm45_bridge.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from functools import partial

import torch
@@ -45,9 +44,6 @@
HAVE_TE = False


logger = logging.getLogger(__name__)


@MegatronModelBridge.register_bridge(source=Glm4MoeForCausalLM, target=GPTModel, model_type="glm4_moe")
class GLM45Bridge(MegatronModelBridge):
"""
@@ -101,14 +97,6 @@ def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> GPTModelProvider

return provider

def build_conversion_tasks(self, hf_pretrained, megatron_model):
"""Override to store config before mapping_registry is called."""
# Store config on instance for use in mapping_registry
self._hf_config = hf_pretrained.config
self._hf_state_source = hf_pretrained.state.source
self._hf_keys = list(self._hf_state_source.get_all_keys())
return super().build_conversion_tasks(hf_pretrained, megatron_model)

def mapping_registry(self) -> MegatronMappingRegistry:
mapping_list = []
use_fused_experts = self._uses_fused_experts()
@@ -206,23 +194,11 @@ def mapping_registry(self) -> MegatronMappingRegistry:
),
]
)
# optionally add MTP mappings
if not hasattr(self, "_hf_config"):
logger.warning("No HF config found, skipping MTP mappings.")
return MegatronMappingRegistry(*mapping_list)
hf_config = self._hf_config
# add MTP mappings
hf_config = self.hf_config
num_mtp_layers = getattr(hf_config, "num_nextn_predict_layers", 0)
num_transformer_layers = hf_config.num_hidden_layers
for mtp_layer in range(num_mtp_layers):
for megatron_param, hf_param in layer_specific_mappings.items():
megatron_param = (
megatron_param.replace(".*", ".*.transformer_layer")
.replace("decoder", "mtp")
.replace(".*", f".{mtp_layer}")
)
hf_param = hf_param.replace("layers.*", f"layers.{mtp_layer + num_transformer_layers}")
mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))

# MTP specific mappings
mapping_list.extend(
[
@@ -244,96 +220,113 @@ def mapping_registry(self) -> MegatronMappingRegistry:
),
]
)
# Special mappings that require parameter concatenation/transformation
mapping_list.extend(
[
QKVMapping(
megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.self_attention.linear_qkv.weight",
q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.weight",
k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.weight",
v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.weight",
),
QKVMapping(
megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.self_attention.linear_qkv.bias",
q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.bias",
k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.bias",
v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.bias",
),
GatedMLPMapping(
megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.mlp.linear_fc1.weight",
gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.gate.weight",
up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.up.weight",
),
GatedMLPMapping(
megatron_param=f"mtp.layers.{mtp_layer}.transformer_layer.mlp.shared_experts.linear_fc1.weight",
gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.gate_proj.weight",
up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.up_proj.weight",
),
]
)
if use_fused_experts:

for layer_prefix in ("transformer_layer", "mtp_model_layer"):
for megatron_param, hf_param in layer_specific_mappings.items():
megatron_param = (
megatron_param.replace(".*", f".*.{layer_prefix}")
.replace("decoder", "mtp")
.replace(".*", f".{mtp_layer}")
)
hf_param = hf_param.replace("layers.*", f"layers.{mtp_layer + num_transformer_layers}")
mapping_list.append(AutoMapping(megatron_param=megatron_param, hf_param=hf_param))

# Special mappings that require parameter concatenation/transformation
mapping_list.extend(
[
GLMExpertGateUpProjMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc1.weight*"
),
hf_param=(
f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.gate_up_proj"
f"{gate_up_suffix}"
),
QKVMapping(
megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.self_attention.linear_qkv.weight",
q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.weight",
k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.weight",
v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.weight",
),
GLMExpertDownProjMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc2.weight*"
),
hf_param=(
f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.down_proj{down_suffix}"
),
QKVMapping(
megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.self_attention.linear_qkv.bias",
q=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.q_proj.bias",
k=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.k_proj.bias",
v=f"model.layers.{mtp_layer + num_transformer_layers}.self_attn.v_proj.bias",
),
]
)
else:
mapping_list.extend(
[
GatedMLPMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc1.weight*"
),
gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.gate_proj.weight",
up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.up_proj.weight",
megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.linear_fc1.weight",
gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.gate.weight",
up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.linear_fc1.up.weight",
),
AutoMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.transformer_layer.mlp.experts.linear_fc2.weight*"
),
hf_param=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.down_proj.weight",
GatedMLPMapping(
megatron_param=f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.shared_experts.linear_fc1.weight",
gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.gate_proj.weight",
up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.shared_experts.up_proj.weight",
),
]
)
if use_fused_experts:
mapping_list.extend(
[
GLMExpertGateUpProjMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc1.weight*"
),
hf_param=(
f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.gate_up_proj"
f"{gate_up_suffix}"
),
),
GLMExpertDownProjMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc2.weight*"
),
hf_param=(
f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.down_proj{down_suffix}"
),
),
]
)
else:
mapping_list.extend(
[
GatedMLPMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc1.weight*"
),
gate=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.gate_proj.weight",
up=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.up_proj.weight",
),
AutoMapping(
megatron_param=(
f"mtp.layers.{mtp_layer}.{layer_prefix}.mlp.experts.linear_fc2.weight*"
),
hf_param=f"model.layers.{mtp_layer + num_transformer_layers}.mlp.experts.*.down_proj.weight",
),
]
)

return MegatronMappingRegistry(*mapping_list)

def _hf_source_and_keys(self):
"""Return HF state source and cached key order for expert-mapping helpers."""
hf_source = self.hf_pretrained.state.source
if getattr(self, "_cached_hf_state_source", None) is not hf_source:
self._cached_hf_state_source = hf_source
self._cached_hf_keys = hf_source.get_all_keys()
return hf_source, self._cached_hf_keys

def _uses_fused_experts(self) -> bool:
hf_keys = getattr(self, "_hf_keys", None)
hf_source, hf_keys = self._hf_source_and_keys()
if hf_keys:
if any("mlp.experts.gate_up_proj" in key for key in hf_keys) or any(
"mlp.experts.down_proj" in key for key in hf_keys
):
return True

hf_source = getattr(self, "_hf_state_source", None)
if hf_source is not None:
return hf_source.has_glob("*mlp.experts.gate_up_proj*") or hf_source.has_glob("*mlp.experts.down_proj*")

return False

def _hf_expert_suffix(self, base_name: str) -> str:
hf_keys = getattr(self, "_hf_keys", None) or []
hf_source, hf_keys = self._hf_source_and_keys()
if any(f"{base_name}.weight" in key for key in hf_keys):
return ".weight"

hf_source = getattr(self, "_hf_state_source", None)
if hf_source is not None and hf_source.has_glob(f"*{base_name}.weight"):
return ".weight"

@@ -351,7 +344,7 @@ def maybe_modify_converted_hf_weight(
if not converted_weights_dict:
return {}

num_experts = self._hf_config.n_routed_experts
num_experts = self.hf_config.n_routed_experts
ep_size = parallel_state.get_expert_model_parallel_world_size()
experts_per_rank = num_experts // ep_size

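For the MTP mappings in glm45_bridge.py, each Megatron MTP layer maps to an HF layer index offset by the number of regular transformer layers, and the new loop registers every mapping under both the `transformer_layer` and `mtp_model_layer` prefixes. A worked sketch of the index arithmetic with illustrative numbers (the real values come from the HF config):

```python
# Illustrative arithmetic only; the real values are hf_config.num_hidden_layers and
# hf_config.num_nextn_predict_layers.
num_transformer_layers = 46
num_mtp_layers = 1

for mtp_layer in range(num_mtp_layers):
    hf_layer = mtp_layer + num_transformer_layers
    # e.g. mtp.layers.0.transformer_layer.self_attention.linear_qkv.weight
    #  <-> model.layers.46.self_attn.{q,k,v}_proj.weight
    print(f"mtp.layers.{mtp_layer}.*  <->  model.layers.{hf_layer}.*")
```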