@@ -158,6 +158,7 @@ def _apply_rotary_pos_emb_bshd(
     rotary_interleaved: bool = False,
     multi_latent_attention: bool = False,  # not use
     mscale: float = 1.0,
+    **kwargs,
 ) -> torch.Tensor:
     """Apply rotary positional embedding to input tensor T.

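The `**kwargs` added to `_apply_rotary_pos_emb_bshd` is a signature-tolerance change: a dispatcher that also serves packed/THD rotary variants can pass the same keyword arguments (for example `packed_seq_params`) to every implementation without the BSHD path raising a `TypeError`. A minimal sketch of the pattern, with a toy rotation body and hypothetical names, not the repo's actual helper:

```python
import torch

def apply_rope_bshd_sketch(t: torch.Tensor, freqs: torch.Tensor, **kwargs) -> torch.Tensor:
    """Toy stand-in for a BSHD rotary helper; not the repo implementation.

    Extra keyword arguments (e.g. packed_seq_params) are accepted and ignored,
    so one call site can feed every rotary variant with a uniform signature.
    """
    cos, sin = freqs.cos(), freqs.sin()
    t1, t2 = torch.chunk(t, 2, dim=-1)  # rotate each pair (t1_i, t2_i)
    return torch.cat((t1 * cos - t2 * sin, t1 * sin + t2 * cos), dim=-1)

# Both calls are valid; the second simply ignores the packing metadata:
# out = apply_rope_bshd_sketch(q, freqs)
# out = apply_rope_bshd_sketch(q, freqs, packed_seq_params=params)
```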
@@ -390,6 +391,8 @@ def _postprocess(
         output_weight = None
         if self.share_embeddings_and_output_weights:
             output_weight = self.shared_embedding_or_output_weight()
+        if self.config.is_multimodal and self.config.context_parallel_size > 1:
+            input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)

         if self.mtp_process:
             hidden_states = self.mtp(
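The new multimodal branch above splits `input_ids` along the sequence dimension before the MTP block when context parallelism is enabled. `split_cp_inputs` is the repo's helper; the sketch below is a generic stand-in (not its actual implementation), assuming the common Megatron-style layout where each packed sub-sequence is cut into `2 * cp_size` chunks and a rank keeps one head chunk plus the mirrored tail chunk for load balancing:

```python
import torch

def split_for_cp_rank(ids: torch.Tensor, cu_seqlens: torch.Tensor,
                      cp_size: int, cp_rank: int) -> torch.Tensor:
    """Hypothetical stand-in for split_cp_inputs(ids, cu_seqlens, 1)."""
    kept = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        seq = ids[:, start:end]                        # one packed sub-sequence
        chunks = torch.chunk(seq, 2 * cp_size, dim=1)  # 2 chunks per CP rank
        kept += [chunks[cp_rank], chunks[2 * cp_size - 1 - cp_rank]]
    return torch.cat(kept, dim=1)

# Example with two length-8 sequences packed into one row and cp_size=2:
# ids = torch.arange(16).unsqueeze(0)
# cu = torch.tensor([0, 8, 16])
# split_for_cp_rank(ids, cu, cp_size=2, cp_rank=0)  # keeps chunks 0 and 3 of each sub-sequence
```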
@@ -406,55 +409,52 @@ def _postprocess(
                 embedding=self.embedding,
                 **(extra_block_kwargs or {}),
             )
+            mtp_labels = labels.clone()
             hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
             hidden_states = hidden_states_list[0]
-
-            if labels is not None:
-                mtp_labels = labels.clone()
-                if loss_mask is None:
-                    # if loss_mask is not provided, use all ones as loss_mask
-                    if packed_seq_params is None:
-                        loss_mask = torch.ones_like(mtp_labels)
-                    else:
-                        loss_mask = mtp_labels.new_ones((1, packed_seq_params.cu_seqlens_q[-1]))
-                cu_seqlens = packed_seq_params.cu_seqlens_q if packed_seq_params is not None else None
-                for mtp_layer_number in range(self.config.mtp_num_layers):
-                    # output
-                    mtp_logits, _ = self.output_layer(
-                        hidden_states_list[mtp_layer_number + 1],
-                        weight=output_weight,
-                        runtime_gather_output=runtime_gather_output,
+            if loss_mask is None:
+                # if loss_mask is not provided, use all ones as loss_mask
+                loss_mask = torch.ones_like(mtp_labels)
+            for mtp_layer_number in range(self.config.mtp_num_layers):
+                # output
+                mtp_logits, _ = self.output_layer(
+                    hidden_states_list[mtp_layer_number + 1],
+                    weight=output_weight,
+                    runtime_gather_output=runtime_gather_output,
+                )
+                # Calc loss for the current Multi-Token Prediction (MTP) layers.
+                mtp_labels, _ = roll_tensor(
+                    mtp_labels,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                loss_mask, _ = roll_tensor(
+                    loss_mask,
+                    shifts=-1,
+                    dims=-1,
+                    cp_group=self.cp_group,
+                    packed_seq_params=packed_seq_params,
+                )
+                mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
+                loss_mask_ = (loss_mask & (mtp_labels != -100))
+                num_tokens = loss_mask_.sum()
+                mtp_loss = loss_mask_ * mtp_loss
+                if self.training:
+                    mtp_loss_for_log = (
+                        torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
+                    MTPLossLoggingHelper.save_loss_to_tracker(
+                        mtp_loss_for_log,
+                        mtp_layer_number,
+                        self.config.mtp_num_layers,
+                        avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
                     )
-                    # Calc loss for the current Multi-Token Prediction (MTP) layers.
-                    mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group)
-                    if cu_seqlens is None:
-                        loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group)
-                        loss_mask_ = loss_mask
-                    else:
-                        loss_mask[:, cu_seqlens[:-1]] = 0
-                        loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1)
-                        if self.config.context_parallel_size > 1:
-                            loss_mask_ = split_cp_inputs(loss_mask, cu_seqlens, dim=1)
-                        else:
-                            loss_mask_ = loss_mask.clone()
-                    mtp_loss = self.compute_language_model_loss(mtp_labels, mtp_logits)
-                    loss_mask_ = loss_mask_ & (mtp_labels != -100)
-                    mtp_loss = loss_mask_ * mtp_loss
-                    num_tokens = loss_mask_.sum()
-                    if self.training:
-                        mtp_loss_for_log = (
-                            torch.sum(mtp_loss) / num_tokens if num_tokens > 0 else mtp_loss.new_tensor(0.0))
-                        MTPLossLoggingHelper.save_loss_to_tracker(
-                            mtp_loss_for_log,
-                            mtp_layer_number,
-                            self.config.mtp_num_layers,
-                            avg_group=parallel_state.get_data_parallel_group(with_context_parallel=True),
-                        )
-                    mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
-                    if self.config.calculate_per_token_loss:
-                        hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
-                    else:
-                        hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
+                mtp_loss_scale = self.config.mtp_loss_scaling_factor / self.config.mtp_num_layers
+                if self.config.calculate_per_token_loss:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss)
+                else:
+                    hidden_states = MTPLossAutoScaler.apply(hidden_states, mtp_loss_scale * mtp_loss / num_tokens)
         sequence_parallel_override = False
         if in_inference_mode and inference_context.materialize_only_last_token_logits:
             if inference_context.is_static_batching():
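For reference, the behaviour the two `roll_tensor` calls rely on can be pictured with a plain single-rank stand-in (no `cp_group` handling; `cu_seqlens` plays the role of `packed_seq_params.cu_seqlens_q`). Each call shifts the labels and the loss mask one token to the left, so invoking it once per MTP depth accumulates the extra-token offset, and any position that wraps across a sequence boundary is masked out of the loss:

```python
from typing import Optional

import torch

def roll_left_and_mask(labels: torch.Tensor, loss_mask: torch.Tensor,
                       cu_seqlens: Optional[torch.Tensor] = None):
    """Hypothetical single-rank stand-in for roll_tensor; not the repo implementation."""
    labels = torch.roll(labels, shifts=-1, dims=-1)
    loss_mask = torch.roll(loss_mask, shifts=-1, dims=-1)
    if cu_seqlens is None:
        loss_mask[..., -1] = 0                # unpacked: only the global last token wraps
    else:
        loss_mask[:, cu_seqlens[1:] - 1] = 0  # packed: the last token of each sub-sequence wraps
    return labels, loss_mask
```

Calling this k + 1 times gives labels shifted by k + 1 positions, which matches the target the k-th MTP head is trained against in the loop above before the masked loss and the `MTPLossAutoScaler` scaling are applied.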