def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None, tp_group=None):
    """Provide a sharded state dictionary for distributed checkpointing.

    Builds ShardedTensor entries for this module's own parameters and its
    submodules, then wraps the fused projection/convolution tensors in
    split-tensor factories so each logical sub-tensor (query/key/value/...)
    is checkpointed as a separate shard.

    Args:
        prefix: State-dict key prefix for this module.
        sharded_offsets: Offsets of this module's shards inside the global
            tensors (e.g. from pipeline/virtual-pipeline placement).
        metadata: Optional checkpointing metadata; a 'dp_cp_group' entry is
            guaranteed to exist after the guard below.
        tp_group: Optional tensor-parallel process group override; falls back
            to ``self.pg_collection.tp`` when not given.

    Returns:
        Dict mapping state-dict keys to ShardedTensor/factory objects.
    """
    from megatron.core.ssm.gated_delta_net import _split_tensor_factory

    # Guard for cases metadata is not provided.
    metadata = ensure_metadata_has_dp_cp_group(metadata)

    # Resolve the TP group once up front instead of repeating the fallback
    # expression at each call site (it was previously duplicated).
    tp_group = tp_group if tp_group is not None else self.pg_collection.tp

    sharded_state_dict = {}
    # Own parameters: A_log and dt_bias are sharded along dim 0 across TP.
    self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
    sharded_state_dict = make_sharded_tensors_for_checkpoint(
        sharded_state_dict,
        prefix,
        tensor_parallel_layers_axis_map={
            'A_log': 0,
            'dt_bias': 0,
        },  # parameters sharded across TP
        sharded_offsets=sharded_offsets,
        tp_group=tp_group,
        dp_cp_group=metadata['dp_cp_group'],
    )

    # Submodules: conv1d needs explicit TP sharding of weight (and bias when
    # present); everything else defers to the default recursive handling.
    for name, module in self.named_children():
        if name == 'conv1d':
            module_sd = module.state_dict(prefix='', keep_vars=True)
            tp_sharding_map = {'weight': 0}
            if self.conv_bias:
                tp_sharding_map['bias'] = 0
            module_sharded_sd = make_sharded_tensors_for_checkpoint(
                module_sd,
                f'{prefix}{name}.',
                tp_sharding_map,
                sharded_offsets,
                tp_group=tp_group,
                dp_cp_group=metadata['dp_cp_group'],
            )
        else:
            module_sharded_sd = sharded_state_dict_default(
                module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=tp_group)

        sharded_state_dict.update(module_sharded_sd)

    # At this point the TP sharding is correctly defined for each tensor, but
    # the fused tensors must additionally be split into separate parts.
    # in_proj packs [query, key, value, z, beta, alpha] along dim 0; the local
    # (per-TP-rank) sizes below must sum to in_proj_dim_local_tp.
    in_proj_dim_local_tp = self.in_proj_dim // self.tp_size
    assert sharded_state_dict[f'{prefix}in_proj.weight'].data.size(0) == in_proj_dim_local_tp, (
        in_proj_dim_local_tp,
        sharded_state_dict[f'{prefix}in_proj.weight'],
    )

    sharded_state_dict[f'{prefix}in_proj.weight'] = _split_tensor_factory(
        sharded_state_dict[f'{prefix}in_proj.weight'],
        [
            self.qk_dim // self.tp_size,
            self.qk_dim // self.tp_size,
            self.v_dim // self.tp_size,
            self.v_dim // self.tp_size,
            self.num_value_heads // self.tp_size,
            self.num_value_heads // self.tp_size,
        ],
        ['query', 'key', 'value', 'z', 'beta', 'alpha'],
        0,
    )

    # conv1d channels are assumed to pack [query, key, value] along dim 0
    # (sizes must sum to conv_dim_local_tp, checked by the asserts below).
    conv_layer_name_list = ['conv1d.weight']
    assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == self.conv_dim_local_tp), (
        self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
    if self.conv_bias:
        conv_layer_name_list.append('conv1d.bias')
        assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == self.conv_dim_local_tp), (
            self.conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])
    for conv_layer_name in conv_layer_name_list:
        sharded_state_dict[f'{prefix}{conv_layer_name}'] = _split_tensor_factory(
            sharded_state_dict[f'{prefix}{conv_layer_name}'],
            [
                self.qk_dim // self.tp_size,
                self.qk_dim // self.tp_size,
                self.v_dim // self.tp_size,
            ],
            ['query', 'key', 'value'],
            0,
        )

    return sharded_state_dict