Commit 27066a2

[feat] Support multimodal MTP (#14)
1 parent 085a837 commit 27066a2

3 files changed: 69 additions & 6 deletions


src/mcore_bridge/config/model_config.py

Lines changed: 0 additions & 2 deletions
@@ -205,9 +205,7 @@ class ModelConfig(TransformerConfig):
 
     # visual
     hf_config: Optional[PretrainedConfig] = None
-    vit_gradient_checkpointing: Optional[bool] = None
     vit_attn_impl: Optional[str] = None  # e.g. 'flash_attention_2'
-    vit_gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
 
     # Override
     perform_initialization: bool = False

src/mcore_bridge/model/gpt_model.py

Lines changed: 15 additions & 3 deletions
@@ -322,6 +322,11 @@ def forward(
             assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
             decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
 
+        mtp_decoder_input = decoder_input
+        if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
+            input_tensor = self.get_input_tensor()
+            input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
+            self.set_input_tensor(input_tensor)
         # Run decoder.
         hidden_states = self.decoder(
             hidden_states=decoder_input,
@@ -346,7 +351,7 @@ def forward(
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
             loss_mask=loss_mask,
-            decoder_input=decoder_input,
+            decoder_input=mtp_decoder_input,
             attention_mask=attention_mask,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
@@ -381,7 +386,10 @@ def _postprocess(
         the output layer, and computes language model loss when labels are provided.
         """
         if not self.post_process:
-            return hidden_states
+            if self.config.is_multimodal and self.config.mtp_num_layers:
+                return torch.concat([hidden_states, decoder_input], dim=0)
+            else:
+                return hidden_states
         labels = labels if self.config.task_type == 'causal_lm' else None
         in_inference_mode = inference_context is not None and not self.training
         if in_inference_mode:
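
The chunk in forward and the concat in _postprocess are two halves of one hand-off: under pipeline parallelism a non-final stage has no separate channel for the multimodal decoder input that the MTP head will need later, so it stacks that tensor onto hidden_states along dim 0, and the receiving stage splits the pair apart again. A minimal stand-alone sketch of the round trip (shapes and variable names are illustrative, not the bridge's actual API):

import torch

seq_len, batch, hidden = 8, 2, 16

# On a non-final pipeline stage: hidden_states and the multimodal
# decoder input (text embeddings with vision features merged in)
# share Megatron's [s, b, h] layout, so they can ride one tensor.
hidden_states = torch.randn(seq_len, batch, hidden)
decoder_input = torch.randn(seq_len, batch, hidden)
stacked = torch.concat([hidden_states, decoder_input], dim=0)  # [2s, b, h]

# On the next stage: undo the stacking before running the decoder.
recv_hidden, recv_decoder_input = stacked.chunk(2, dim=0)  # each [s, b, h]

assert torch.equal(recv_hidden, hidden_states)
assert torch.equal(recv_decoder_input, decoder_input)

Stacking along dim 0 works without extra metadata precisely because both tensors share the same [sequence, batch, hidden] layout, so an even split recovers them.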
@@ -395,6 +403,10 @@
             input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
 
         if self.mtp_process:
+            if self.config.is_multimodal:
+                embedding_ = (self.embedding, decoder_input)
+            else:
+                embedding_ = self.embedding
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -406,7 +418,7 @@
                 rotary_pos_sin=rotary_pos_sin,
                 packed_seq_params=packed_seq_params,
                 sequence_len_offset=sequence_len_offset,
-                embedding=self.embedding,
+                embedding=embedding_,
                 **(extra_block_kwargs or {}),
             )
             mtp_labels = labels.clone()
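
A text-only MTP head can re-embed input_ids on its own, but for a multimodal model that would discard the vision features already merged into decoder_input, so the commit passes the precomputed tensor alongside the embedding module as a tuple. A hedged sketch of the dispatch convention that the patched _get_embeddings in patcher.py consumes (resolve_embedding is an illustrative helper, and the roll and sequence-parallel handling of the real patch are omitted here):

from typing import Callable, Optional, Tuple, Union

import torch

EmbeddingArg = Union[Callable, Tuple[Callable, Optional[torch.Tensor]]]

def resolve_embedding(embedding: EmbeddingArg,
                      input_ids: torch.Tensor,
                      position_ids: torch.Tensor) -> torch.Tensor:
    # Tuple form: (embedding module, precomputed multimodal decoder_input).
    if isinstance(embedding, tuple):
        embedding, decoder_input = embedding
        if decoder_input is not None:
            # Reuse embeddings that already carry the merged vision
            # features; re-embedding input_ids would drop them.
            return decoder_input
    # Plain form: embed the token ids, as upstream Megatron does.
    return embedding(input_ids=input_ids, position_ids=position_ids)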

src/mcore_bridge/patcher.py

Lines changed: 54 additions & 1 deletion
@@ -21,7 +21,7 @@
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
 from torch import nn
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 from mcore_bridge.utils import get_logger, is_flash_attn_3_available
 
@@ -471,6 +471,59 @@ def forward(
 
     MultiTokenPredictionLayer.forward = forward
 
+    def _get_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        embedding: Callable,
+        hidden_states: torch.Tensor,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+    ):
+        from megatron.core.transformer.multi_token_prediction import roll_tensor
+        from megatron.core.utils import make_viewless_tensor
+
+        # Calc logits for the current Multi-Token Prediction (MTP) layers.
+        input_ids, _ = roll_tensor(
+            input_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        position_ids, _ = roll_tensor(
+            position_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        # embedding
+        if isinstance(embedding, tuple):
+            embedding, decoder_input = embedding
+        else:
+            decoder_input = None
+        if decoder_input is None:
+            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+        else:
+            enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+            if enable_sp:
+                decoder_input = gather_from_sequence_parallel_region(decoder_input)
+            decoder_input, _ = roll_tensor(
+                decoder_input.transpose(0, 2),
+                shifts=-1,
+                dims=-1,
+                cp_group=self.cp_group,
+                packed_seq_params=packed_seq_params,
+            )
+            decoder_input = decoder_input.transpose(0, 2).contiguous()
+            if enable_sp:
+                decoder_input = scatter_to_sequence_parallel_region(decoder_input)
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+        return input_ids, position_ids, decoder_input, hidden_states
+
+    MultiTokenPredictionLayer._get_embeddings = _get_embeddings
+
 
 def _patch_peft_ModulesToSaveWrapper():
     if version.parse(peft.__version__) >= version.parse('0.16'):
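
Megatron's roll_tensor shifts a tensor one step toward the front of a dim (zero-filling the vacated last slot), which is how MTP aligns each position with the token one step ahead. Because decoder_input uses the [s, b, h] layout with the sequence on dim 0, the patch transposes dims 0 and 2 so the roll's last-dim shift moves along the sequence, and it gathers/scatters across sequence-parallel ranks first so the roll sees the full sequence. A simplified single-rank stand-in (assuming roll_tensor's zero-fill behavior, and ignoring context parallelism and packed sequences):

import torch

def roll_last_dim(t: torch.Tensor) -> torch.Tensor:
    # Simplified roll_tensor: shift left by one along the last dim,
    # then zero the now-undefined final position.
    rolled = torch.roll(t, shifts=-1, dims=-1)
    rolled[..., -1] = 0
    return rolled

s, b, h = 6, 2, 4
decoder_input = torch.randn(s, b, h)  # Megatron's [s, b, h] layout

# Transpose so the sequence axis is last, roll, transpose back.
shifted = roll_last_dim(decoder_input.transpose(0, 2)).transpose(0, 2).contiguous()

# Position i now holds the embedding of token i + 1; the tail is zeroed.
assert torch.equal(shifted[:-1], decoder_input[1:])
assert torch.equal(shifted[-1], torch.zeros(b, h))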
