diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index acc17d2e..b33ad449 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -362,25 +362,30 @@ def _is_trajectory(self, obj: Any) -> bool: """Check if an object is a Trajectory (has 'messages' key).""" return isinstance(obj, Mapping) and 'messages' in obj - def _get_trajectory_keys(self, columnar: Mapping) -> List[str]: - """Get keys whose values are lists of Trajectories in columnar format.""" + def _get_trajectory_keys(self, trajectories: Mapping, is_columnar: bool) -> List[str]: + """Get keys whose values are lists of Trajectories.""" keys = [] - for k, v in columnar.items(): - if isinstance(v, list) and v and self._is_trajectory(v[0]): - keys.append(k) + if is_columnar: + for k, v in trajectories.items(): + if isinstance(v, list) and v and self._is_trajectory(v[0]): + keys.append(k) + else: + for k, v in trajectories.items(): + if v and self._is_trajectory(v): + keys.append(k) return keys def batch_encode( self, - trajectories: Union[Dict[str, Any], List[Trajectory]], + trajectories: Union[Dict[str, List[Any]], List[Trajectory]], add_generation_prompt: bool = False, ) -> Union[Dict[str, Any], List[InputFeature]]: """Encode trajectories into InputFeatures. Args: trajectories: Either List[Trajectory] or columnar Dict[str, List]. - For DPO, columnar format with 'positive'/'negative' keys containing - List[Trajectory] is supported. + For nested trajectories, columnar format with trajectory list columns + (e.g., 'chosen'/'rejected') is supported. add_generation_prompt: Whether to add generation prompt. Returns: @@ -388,21 +393,27 @@ def batch_encode( """ _transfer = False + # Handle list input + if isinstance(trajectories, list) and len(trajectories) > 0: + # Check if first element has nested trajectories + if isinstance(trajectories[0], Mapping) and len(self._get_trajectory_keys(trajectories[0], False)) > 0: + # Convert row→columnar, process with columnar logic, convert back + columnar = self.map_row_to_col(trajectories) + encoded = self.batch_encode(columnar, add_generation_prompt) + return self.map_col_to_row(encoded) + if isinstance(trajectories, Mapping): _transfer = True - # Check if it has trajectory list columns (DPO format) - traj_keys = self._get_trajectory_keys(trajectories) + # Check if it has nested trajectory columns + traj_keys = self._get_trajectory_keys(trajectories, True) if traj_keys: - # DPO format: encode each trajectory list separately, keep other columns - result = {} - for key in trajectories: - if key in traj_keys: - # Encode this trajectory list - result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt) - else: - # Keep non-trajectory columns as-is - result[key] = trajectories[key] - return result + # Nested format: encode each trajectory list separately, keep other columns + return { + key: + self.batch_encode(trajectories[key], add_generation_prompt) + if key in traj_keys else trajectories[key] + for key in trajectories + } else: # Standard columnar format trajectories = self.map_col_to_row(trajectories)