@@ -290,6 +290,51 @@ def _not_encoded(inputs):
290290 assert isinstance (inputs , dict )
291291 return 'input_ids' not in inputs and 'input_embedding' not in inputs
292292
def _postprocess_tensor_cp(self, tensor):
    """Rebuild the full-sequence tensor from a context-parallel (CP) shard.

    The load-balanced CP layout splits the sequence into ``2 * cp_size``
    chunks; rank ``r`` owns chunks ``r`` and ``2 * cp_size - 1 - r``.  After
    all-gathering every rank's shard, the chunks are scattered back to their
    original sequence positions.

    The locally-held shard is re-inserted as the original tensor object, so
    autograd only flows into the chunks this rank owns; the gathered copies
    from other ranks carry no graph.  Gradients are therefore naturally
    partitioned across CP ranks without any extra scaling.

    Args:
        tensor: [batch_size, seq_len / cp_size] shard held by this CP rank.

    Returns:
        [batch_size, seq_len] tensor with all chunks restored to order.
    """
    from megatron.core import parallel_state as mpu

    world = mpu.get_context_parallel_world_size()
    if world <= 1:
        # No context parallelism: the shard already is the full sequence.
        return tensor

    rank = mpu.get_context_parallel_rank()
    group = mpu.get_context_parallel_group()

    shards = [torch.empty_like(tensor) for _ in range(world)]
    torch.distributed.all_gather(shards, tensor.contiguous(), group=group)
    # Keep the autograd graph only for the shard this rank owns.
    shards[rank] = tensor

    batch, local_len = tensor.shape[0], tensor.shape[1]
    total_len = local_len * world
    chunk = total_len // (2 * world)  # == local_len // 2
    half = local_len // 2

    full = tensor.new_zeros(batch, total_len)
    for r, shard in enumerate(shards):
        front, back = shard[:, :half], shard[:, half:]
        full[:, r * chunk:(r + 1) * chunk] = front
        mirror = 2 * world - 1 - r
        full[:, mirror * chunk:(mirror + 1) * chunk] = back

    return full
337+
@remote_function()
def forward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], **kwargs):
    """Unsupported on the Megatron backend; always raises.

    Raises:
        NotImplementedError: callers must use ``forward_backward`` or
            ``forward_only`` instead.
    """
    raise NotImplementedError('Megatron only supports `forward_backward` and `forward_only`')
@@ -420,13 +465,13 @@ def post_loss_function(output_tensor, inputs, logps):
420465 mb_idx = _mb_counter [0 ]
421466 _mb_counter [0 ] += 1
422467 current_kwargs = loss_extra_kwargs_per_mb [mb_idx % len (loss_extra_kwargs_per_mb )]
423- outputs = ModelOutput (logits = output_tensor )
468+ outputs = ModelOutput (logits = output_tensor , logps = logps )
424469 result = loss_instance (inputs , outputs , ** current_kwargs )
425470 losses = result ['loss' ]
426471 counts = result ['num_tokens' ]
427472 if not counts :
428473 counts = torch .tensor (1 , device = losses .device )
429- return self .strategy .gather_loss_for_cp (losses , counts , output_tensor , logps )
474+ return self .strategy .reduce_loss (losses , counts , output_tensor , logps )
430475
431476 # Define forward step function for Megatron
432477 # forward_step_func(data_iterator, model) -> (output_tensor, partial(loss_func))
@@ -435,11 +480,15 @@ def forward_step_func(data_iterator, model):
435480 labels = batch .pop ('labels' , None )
436481 output_tensor = model (** batch )
437482 batch ['labels' ] = labels
483+ logps = None
438484 if labels is not None :
439485 loss_mask = (labels != - 100 ).bool ()
440486 masked_labels = labels .clone ()
441487 masked_labels [~ loss_mask ] = 0
442488 logps = selective_log_softmax (output_tensor , masked_labels )
489+ if cp_size > 1 :
490+ logps = self ._postprocess_tensor_cp (logps )
491+ batch ['labels' ] = self ._postprocess_tensor_cp (labels )
443492 return output_tensor , partial (post_loss_function , inputs = batch , logps = logps )
444493
445494 # Get Megatron's forward-backward function
@@ -514,7 +563,7 @@ def forward_step_func(data_iterator, model):
514563 torch .distributed .all_reduce (loss , op = torch .distributed .ReduceOp .AVG , group = dp_cp_group )
515564
516565 optimizer_config .inputs = inputs
517- if len ({_logps .shape [1 ] for _logps in logps }) == 1 :
566+ if logps and len ({_logps .shape [1 ] for _logps in logps }) == 1 :
518567 logps = torch .cat (logps , dim = 0 )
519568 if isinstance (loss , torch .Tensor ):
520569 loss = loss .detach ().cpu ().float ().numpy ()
0 commit comments