def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None, tp_group=None):
    """Provide a sharded state dictionary for distributed checkpointing.

    Builds ShardedTensor wrappers for this GatedDeltaNet layer's parameters so
    they can be saved/loaded with Megatron-Core distributed checkpointing, then
    replaces the fused ``in_proj`` / ``conv1d`` tensors with split factories so
    each logical sub-projection is checkpointed as its own shard.

    Args:
        prefix: Key prefix prepended to every entry in the returned dict.
        sharded_offsets: Offsets of this layer inside any outer sharding
            (e.g. pipeline-parallel layer offsets), forwarded to the
            checkpointing helpers.
        metadata: Optional checkpointing metadata dict; ``dp_cp_group`` is
            filled in by ``ensure_metadata_has_dp_cp_group`` when absent.
        tp_group: Optional tensor-parallel process group override; defaults to
            ``self.pg_collection.tp``.

    Returns:
        Dict mapping checkpoint keys to ShardedTensor / factory objects.
    """
    # Deferred import: pulled in at call time rather than module load
    # (presumably to avoid a circular import with megatron.core.ssm —
    # TODO confirm before moving it to the top of the file).
    from megatron.core.ssm.gated_delta_net import _split_tensor_factory

    # Guard for cases metadata is not provided.
    metadata = ensure_metadata_has_dp_cp_group(metadata)

    # Resolve the TP group once; it is needed both for this layer's own
    # parameters and for every submodule below.
    tp_group = tp_group if tp_group is not None else self.pg_collection.tp

    sharded_state_dict = {}
    # Direct parameters of this module (excludes submodules).
    self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
    sharded_state_dict = make_sharded_tensors_for_checkpoint(
        sharded_state_dict,
        prefix,
        tensor_parallel_layers_axis_map={
            'A_log': 0,
            'dt_bias': 0,
        },  # parameters sharded across TP on dim 0
        sharded_offsets=sharded_offsets,
        tp_group=tp_group,
        dp_cp_group=metadata['dp_cp_group'],
    )

    # Submodules: conv1d needs explicit TP axis info (a plain Conv1d carries
    # none); everything else uses the generic default handling.
    for name, module in self.named_children():
        if name == 'conv1d':
            # Add TP sharding for Conv1d: weight (and bias, when present)
            # are partitioned along the channel dim.
            module_sd = module.state_dict(prefix='', keep_vars=True)
            tp_sharding_map = {'weight': 0}
            if self.conv_bias:
                tp_sharding_map['bias'] = 0
            module_sharded_sd = make_sharded_tensors_for_checkpoint(
                module_sd,
                f'{prefix}{name}.',
                tp_sharding_map,
                sharded_offsets,
                tp_group=tp_group,
                dp_cp_group=metadata['dp_cp_group'],
            )
        else:
            module_sharded_sd = sharded_state_dict_default(
                module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=tp_group)

        sharded_state_dict.update(module_sharded_sd)

    # At this point the TP sharding is correctly defined for each tensor, but
    # some of the fused tensors must be additionally split into separate parts.
    in_proj_dim_local_tp = self.in_proj_dim // self.tp_size
    assert sharded_state_dict[f'{prefix}in_proj.weight'].data.size(0) == in_proj_dim_local_tp, (
        in_proj_dim_local_tp,
        sharded_state_dict[f'{prefix}in_proj.weight'],
    )

    # in_proj fuses [query | key | value | z | beta | alpha] along dim 0;
    # beta/alpha segments are one scalar per value head (hence the head-count
    # sizes). NOTE(review): segment order/layout assumed from this split
    # list — confirm against the in_proj construction elsewhere in the file.
    sharded_state_dict[f'{prefix}in_proj.weight'] = _split_tensor_factory(
        sharded_state_dict[f'{prefix}in_proj.weight'],
        [
            self.qk_dim // self.tp_size,
            self.qk_dim // self.tp_size,
            self.v_dim // self.tp_size,
            self.v_dim // self.tp_size,
            self.num_value_heads // self.tp_size,
            self.num_value_heads // self.tp_size,
        ],
        ['query', 'key', 'value', 'z', 'beta', 'alpha'],
        0,
    )

    # conv1d fuses [query | key | value] along its channel dim; split weight
    # (and bias when configured) the same way.
    conv_layer_name_list = ['conv1d.weight']
    assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == self.conv_dim_local_tp), (
        self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
    if self.conv_bias:
        conv_layer_name_list.append('conv1d.bias')
        assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == self.conv_dim_local_tp), (
            self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
    for conv_layer_name in conv_layer_name_list:
        sharded_state_dict[f'{prefix}{conv_layer_name}'] = _split_tensor_factory(
            sharded_state_dict[f'{prefix}{conv_layer_name}'],
            [
                self.qk_dim // self.tp_size,
                self.qk_dim // self.tp_size,
                self.v_dim // self.tp_size,
            ],
            ['query', 'key', 'value'],
            0,
        )

    return sharded_state_dict