Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions src/mcore_bridge/model/modules/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch.nn.functional as F
from megatron.core.inference.contexts import BaseInferenceContext
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.core.transformer.utils import (ensure_metadata_has_dp_cp_group, make_sharded_tensors_for_checkpoint,
sharded_state_dict_default)
from typing import List, Optional

try:
Expand Down Expand Up @@ -312,3 +314,90 @@ def forward(
nvtx_range_pop(suffix='out_proj')

return out, out_bias

def sharded_state_dict(self, prefix='', sharded_offsets=(), metadata=None, tp_group=None):
    """Provide a sharded state dictionary for distributed checkpointing.

    Builds ShardedTensor entries for this module's own parameters and its
    submodules, marking tensor-parallel sharding axes, then wraps the fused
    ``in_proj`` / ``conv1d`` tensors in split factories so each logical
    component (query/key/value/...) is checkpointed independently.

    Args:
        prefix: Key prefix prepended to every entry in the returned dict.
        sharded_offsets: Offsets of this module's shards in the global tensors
            (passed through to ``make_sharded_tensors_for_checkpoint``).
        metadata: Optional checkpointing metadata; a ``dp_cp_group`` entry is
            ensured to exist before use.
        tp_group: Optional tensor-parallel process group; defaults to
            ``self.pg_collection.tp``.

    Returns:
        Dict mapping checkpoint keys to sharded tensors / factories.
    """
    from megatron.core.ssm.gated_delta_net import _split_tensor_factory

    # Guard for cases metadata is not provided
    metadata = ensure_metadata_has_dp_cp_group(metadata)

    # Resolve the TP group once; used for both own params and submodules.
    tp_group = tp_group if tp_group is not None else self.pg_collection.tp

    sharded_state_dict = {}
    # Parameters owned directly by this module.
    self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
    sharded_state_dict = make_sharded_tensors_for_checkpoint(
        sharded_state_dict,
        prefix,
        tensor_parallel_layers_axis_map={
            'A_log': 0,
            'dt_bias': 0,
        },  # parameters sharded across TP
        sharded_offsets=sharded_offsets,
        tp_group=tp_group,
        dp_cp_group=metadata['dp_cp_group'],
    )
    # Submodules
    for name, module in self.named_children():
        if name == 'conv1d':
            # conv1d is not a TE/megatron module, so it has no
            # sharded_state_dict of its own; add TP sharding explicitly
            # (dim 0 is the channel dim, sharded across TP).
            module_sd = module.state_dict(prefix='', keep_vars=True)
            tp_sharding_map = {'weight': 0}
            if self.conv_bias:
                tp_sharding_map['bias'] = 0
            module_sharded_sd = make_sharded_tensors_for_checkpoint(
                module_sd,
                f'{prefix}{name}.',
                tp_sharding_map,
                sharded_offsets,
                tp_group=tp_group,
                dp_cp_group=metadata['dp_cp_group'],
            )
        else:
            module_sharded_sd = sharded_state_dict_default(
                module, f'{prefix}{name}.', sharded_offsets, metadata, tp_group=tp_group)

        sharded_state_dict.update(module_sharded_sd)

    # At this point the TP sharding is correctly defined for each tensor, but some of the
    # tensors must be additionally split into separate parts
    in_proj_dim_local_tp = self.in_proj_dim // self.tp_size
    assert sharded_state_dict[f'{prefix}in_proj.weight'].data.size(0) == in_proj_dim_local_tp, (
        in_proj_dim_local_tp,
        sharded_state_dict[f'{prefix}in_proj.weight'],
    )

    sharded_state_dict[f'{prefix}in_proj.weight'] = _split_tensor_factory(
        sharded_state_dict[f'{prefix}in_proj.weight'],
        [
            self.qk_dim // self.tp_size,
            self.qk_dim // self.tp_size,
            self.v_dim // self.tp_size,
            self.v_dim // self.tp_size,
            self.num_value_heads // self.tp_size,
            self.num_value_heads // self.tp_size,
        ],
        ['query', 'key', 'value', 'z', 'beta', 'alpha'],
        0,
    )

    # Local (per-TP-rank) channel dim of conv1d: the conv input is the
    # concatenation of query, key and value projections.
    # Fix: this was previously referenced as the undefined attribute
    # `self.conv_dim_local_tp`, which raised AttributeError at runtime.
    conv_dim_local_tp = (2 * self.qk_dim + self.v_dim) // self.tp_size
    conv_layer_name_list = ['conv1d.weight']
    assert (sharded_state_dict[f'{prefix}conv1d.weight'].data.size(0) == conv_dim_local_tp), (
        conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.weight'])
    if self.conv_bias:
        conv_layer_name_list.append('conv1d.bias')
        assert (sharded_state_dict[f'{prefix}conv1d.bias'].data.size(0) == conv_dim_local_tp), (
            conv_dim_local_tp, sharded_state_dict[f'{prefix}conv1d.bias'])

    for conv_layer_name in conv_layer_name_list:
        sharded_state_dict[f'{prefix}{conv_layer_name}'] = _split_tensor_factory(
            sharded_state_dict[f'{prefix}{conv_layer_name}'],
            [
                self.qk_dim // self.tp_size,
                self.qk_dim // self.tp_size,
                self.v_dim // self.tp_size,
            ],
            ['query', 'key', 'value'],
            0,
        )

    return sharded_state_dict
Loading