@@ -250,6 +250,38 @@ def _not_encoded(inputs):
250250 assert isinstance (inputs , dict )
251251 return 'input_ids' not in inputs and 'input_embedding' not in inputs
252252
253+ @staticmethod
254+ def _slice_value_for_microbatch (value , mb_start : int , mb_end : int , micro_batch_size : int ):
255+ """Recursively slice a value for microbatch processing.
256+
257+ Handles nested dicts (e.g., ref_outputs: {"logps": tensor}) by recursively
258+ slicing internal tensors.
259+
260+ Args:
261+ value: The value to slice (tensor, ndarray, list, dict, or scalar)
262+ mb_start: Start index of the microbatch
263+ mb_end: End index of the microbatch
264+ micro_batch_size: Size of each microbatch
265+
266+ Returns:
267+ Sliced value with the same structure
268+ """
269+ if isinstance (value , torch .Tensor ) and value .dim () >= 1 and value .shape [0 ] > micro_batch_size :
270+ return value [mb_start :mb_end ]
271+ elif isinstance (value , np .ndarray ) and value .ndim >= 1 and value .shape [0 ] > micro_batch_size :
272+ return value [mb_start :mb_end ]
273+ elif isinstance (value , (list , tuple )) and len (value ) > micro_batch_size :
274+ return value [mb_start :mb_end ]
275+ elif isinstance (value , dict ):
276+ # Recursively slice dict values (e.g., ref_outputs: {"logps": tensor})
277+ return {
278+ k : MegatronModel ._slice_value_for_microbatch (v , mb_start , mb_end , micro_batch_size )
279+ for k , v in value .items ()
280+ }
281+ else :
282+ # Scalars, small tensors, or non-sliceable values pass through as-is
283+ return value
284+
253285 def _postprocess_tensor_cp (self , tensor ):
254286 """All-gather and reconstruct full sequence from CP-split tensor.
255287
@@ -401,8 +433,6 @@ def forward_backward(self,
401433 else :
402434 seq_length = original_seq_length
403435
404- if 'ref_outputs' in kwargs :
405- breakpoint ()
406436 num_microbatches = len (inputs )
407437 loss_extra_kwargs_per_mb = []
408438 if num_microbatches <= 1 :
@@ -411,17 +441,10 @@ def forward_backward(self,
411441 for mb_idx in range (num_microbatches ):
412442 mb_start = mb_idx * micro_batch_size
413443 mb_end = mb_start + micro_batch_size
414- mb_kwargs = {}
415- for key , value in kwargs .items ():
416- if isinstance (value , torch .Tensor ) and value .dim () >= 1 and value .shape [0 ] > micro_batch_size :
417- mb_kwargs [key ] = value [mb_start :mb_end ]
418- elif isinstance (value , np .ndarray ) and value .ndim >= 1 and value .shape [0 ] > micro_batch_size :
419- mb_kwargs [key ] = value [mb_start :mb_end ]
420- elif isinstance (value , (list , tuple )) and len (value ) > micro_batch_size :
421- mb_kwargs [key ] = value [mb_start :mb_end ]
422- else :
423- # Scalars, small tensors, or non-sliceable values pass through as-is
424- mb_kwargs [key ] = value
444+ mb_kwargs = {
445+ key : self ._slice_value_for_microbatch (value , mb_start , mb_end , micro_batch_size )
446+ for key , value in kwargs .items ()
447+ }
425448 loss_extra_kwargs_per_mb .append (mb_kwargs )
426449
427450 _mb_counter = [0 ] # mutable counter for closure
0 commit comments