@@ -362,47 +362,58 @@ def _is_trajectory(self, obj: Any) -> bool:
362362 """Check if an object is a Trajectory (has 'messages' key)."""
363363 return isinstance (obj , Mapping ) and 'messages' in obj
364364
def _get_trajectory_keys(self, trajectories: Mapping, is_columnar: bool = True) -> List[str]:
    """Return the keys of ``trajectories`` whose values hold Trajectories.

    Args:
        trajectories: Mapping to inspect.
        is_columnar: When True (the default), a key matches if its value is
            a non-empty list whose first element is a Trajectory. When
            False, a key matches if its value is itself a Trajectory.

    Returns:
        The matching keys, in the mapping's iteration order.
    """
    if is_columnar:
        # Only the first element of each list is inspected; the `value`
        # truthiness test guards the `value[0]` access on empty lists.
        return [
            key for key, value in trajectories.items()
            if isinstance(value, list) and value and self._is_trajectory(value[0])
        ]
    # No truthiness test on the value here: `_is_trajectory` already rejects
    # anything that is not a Mapping containing 'messages', and calling
    # bool() on arbitrary row values can raise for array-like objects.
    return [key for key, value in trajectories.items() if self._is_trajectory(value)]
372377
373378 def batch_encode (
374379 self ,
375- trajectories : Union [Dict [str , Any ], List [Trajectory ]],
380+ trajectories : Union [Dict [str , List [ Any ] ], List [Trajectory ]],
376381 add_generation_prompt : bool = False ,
377382 ) -> Union [Dict [str , Any ], List [InputFeature ]]:
378383 """Encode trajectories into InputFeatures.
379384
380385 Args:
381386 trajectories: Either List[Trajectory] or columnar Dict[str, List].
382- For DPO , columnar format with 'positive'/'negative' keys containing
383- List[Trajectory] is supported.
387+ For nested trajectories , columnar format with trajectory list columns
388+ (e.g., 'chosen'/'rejected') is supported.
384389 add_generation_prompt: Whether to add generation prompt.
385390
386391 Returns:
387392 List[InputFeature] or columnar Dict[str, List[InputFeature]].
388393 """
389394 _transfer = False
390395
396+ # Handle list input
397+ if isinstance (trajectories , list ) and len (trajectories ) > 0 :
398+ # Check if first element has nested trajectories
399+ if isinstance (trajectories [0 ], Mapping ) and len (self ._get_trajectory_keys (trajectories [0 ], False )) > 0 :
400+ # Convert row→columnar, process with columnar logic, convert back
401+ columnar = self .map_row_to_col (trajectories )
402+ encoded = self .batch_encode (columnar , add_generation_prompt )
403+ return self .map_col_to_row (encoded )
404+
391405 if isinstance (trajectories , Mapping ):
392406 _transfer = True
393- # Check if it has trajectory list columns (DPO format)
394- traj_keys = self ._get_trajectory_keys (trajectories )
407+ # Check if it has nested trajectory columns
408+ traj_keys = self ._get_trajectory_keys (trajectories , True )
395409 if traj_keys :
396- # DPO format: encode each trajectory list separately, keep other columns
397- result = {}
398- for key in trajectories :
399- if key in traj_keys :
400- # Encode this trajectory list
401- result [key ] = self .batch_encode (trajectories [key ], add_generation_prompt = add_generation_prompt )
402- else :
403- # Keep non-trajectory columns as-is
404- result [key ] = trajectories [key ]
405- return result
410+ # Nested format: encode each trajectory list separately, keep other columns
411+ return {
412+ key :
413+ self .batch_encode (trajectories [key ], add_generation_prompt )
414+ if key in traj_keys else trajectories [key ]
415+ for key in trajectories
416+ }
406417 else :
407418 # Standard columnar format
408419 trajectories = self .map_col_to_row (trajectories )
0 commit comments