@@ -322,6 +322,11 @@ def forward(
322322 assert position_ids .shape [0 ] == 1 , f'position_ids.shape: { position_ids .shape } '
323323 decoder_rotary_pos_emb = rotary_pos_emb [position_ids [0 ]]
324324
325+ mtp_decoder_input = decoder_input
326+ if self .config .is_multimodal and self .config .mtp_num_layers and decoder_input is None :
327+ input_tensor = self .get_input_tensor ()
328+ input_tensor , mtp_decoder_input = input_tensor .chunk (2 , dim = 0 )
329+ self .set_input_tensor (input_tensor )
325330 # Run decoder.
326331 hidden_states = self .decoder (
327332 hidden_states = decoder_input ,
@@ -346,7 +351,7 @@ def forward(
346351 rotary_pos_cos = rotary_pos_cos ,
347352 rotary_pos_sin = rotary_pos_sin ,
348353 loss_mask = loss_mask ,
349- decoder_input = decoder_input ,
354+ decoder_input = mtp_decoder_input ,
350355 attention_mask = attention_mask ,
351356 inference_params = inference_params ,
352357 packed_seq_params = packed_seq_params ,
@@ -381,7 +386,10 @@ def _postprocess(
381386 the output layer, and computes language model loss when labels are provided.
382387 """
383388 if not self .post_process :
384- return hidden_states
389+ if self .config .is_multimodal and self .config .mtp_num_layers :
390+ return torch .concat ([hidden_states , decoder_input ], dim = 0 )
391+ else :
392+ return hidden_states
385393 labels = labels if self .config .task_type == 'causal_lm' else None
386394 in_inference_mode = inference_context is not None and not self .training
387395 if in_inference_mode :
@@ -395,6 +403,10 @@ def _postprocess(
395403 input_ids = split_cp_inputs (input_ids , getattr (packed_seq_params , 'cu_seqlens_q' , None ), 1 )
396404
397405 if self .mtp_process :
406+ if self .config .is_multimodal :
407+ embedding_ = (self .embedding , decoder_input )
408+ else :
409+ embedding_ = self .embedding
398410 hidden_states = self .mtp (
399411 input_ids = input_ids ,
400412 position_ids = position_ids ,
@@ -406,7 +418,7 @@ def _postprocess(
406418 rotary_pos_sin = rotary_pos_sin ,
407419 packed_seq_params = packed_seq_params ,
408420 sequence_len_offset = sequence_len_offset ,
409- embedding = self . embedding ,
421+ embedding = embedding_ ,
410422 ** (extra_block_kwargs or {}),
411423 )
412424 mtp_labels = labels .clone ()
0 commit comments