From bac261a74b8403177db8a10743f51002ac672634 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 20:54:57 +0800 Subject: [PATCH 1/2] fix --- src/twinkle/template/base.py | 51 ++++++++++++++++++++++-------------- 1 file changed, 31 insertions(+), 20 deletions(-) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index acc17d2e..7eb6ef88 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -362,25 +362,30 @@ def _is_trajectory(self, obj: Any) -> bool: """Check if an object is a Trajectory (has 'messages' key).""" return isinstance(obj, Mapping) and 'messages' in obj - def _get_trajectory_keys(self, columnar: Mapping) -> List[str]: - """Get keys whose values are lists of Trajectories in columnar format.""" + def _get_trajectory_keys(self, trajectories: Mapping, is_columar: bool) -> List[str]: + """Get keys whose values are lists of Trajectories.""" keys = [] - for k, v in columnar.items(): - if isinstance(v, list) and v and self._is_trajectory(v[0]): - keys.append(k) + if is_columar: + for k, v in trajectories.items(): + if isinstance(v, list) and v and self._is_trajectory(v[0]): + keys.append(k) + else: + for k, v in trajectories.items(): + if isinstance(v, dict) and v and self._is_trajectory(v): + keys.append(k) return keys def batch_encode( self, - trajectories: Union[Dict[str, Any], List[Trajectory]], + trajectories: Union[Dict[str, List[Any]], List[Trajectory]], add_generation_prompt: bool = False, ) -> Union[Dict[str, Any], List[InputFeature]]: """Encode trajectories into InputFeatures. Args: trajectories: Either List[Trajectory] or columnar Dict[str, List]. - For DPO, columnar format with 'positive'/'negative' keys containing - List[Trajectory] is supported. + For nested trajectories, columnar format with trajectory list columns + (e.g., 'chosen'/'rejected') is supported. add_generation_prompt: Whether to add generation prompt. Returns: @@ -388,21 +393,27 @@ def batch_encode( """ _transfer = False + # Handle list input + if isinstance(trajectories, list) and len(trajectories) > 0: + # Check if first element has nested trajectories + if isinstance(trajectories[0], Mapping) and len(self._get_trajectory_keys(trajectories[0], False)) > 0: + # Convert row→columnar, process with columnar logic, convert back + columnar = self.map_row_to_col(trajectories) + encoded = self.batch_encode(columnar, add_generation_prompt) + return self.map_col_to_row(encoded) + if isinstance(trajectories, Mapping): _transfer = True - # Check if it has trajectory list columns (DPO format) - traj_keys = self._get_trajectory_keys(trajectories) + # Check if it has nested trajectory columns + traj_keys = self._get_trajectory_keys(trajectories, True) if traj_keys: - # DPO format: encode each trajectory list separately, keep other columns - result = {} - for key in trajectories: - if key in traj_keys: - # Encode this trajectory list - result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt) - else: - # Keep non-trajectory columns as-is - result[key] = trajectories[key] - return result + # Nested format: encode each trajectory list separately, keep other columns + return { + key: + self.batch_encode(trajectories[key], add_generation_prompt) + if key in traj_keys else trajectories[key] + for key in trajectories + } else: # Standard columnar format trajectories = self.map_col_to_row(trajectories) From 77d457c6355f453a63a812ad90eca819301f6f32 Mon Sep 17 00:00:00 2001 From: tastelikefeet Date: Tue, 31 Mar 2026 22:00:01 +0800 Subject: [PATCH 2/2] fix --- src/twinkle/template/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 7eb6ef88..b33ad449 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -362,16 +362,16 @@ def _is_trajectory(self, obj: Any) -> bool: """Check if an object is a Trajectory (has 'messages' key).""" return isinstance(obj, Mapping) and 'messages' in obj - def _get_trajectory_keys(self, trajectories: Mapping, is_columar: bool) -> List[str]: + def _get_trajectory_keys(self, trajectories: Mapping, is_columnar: bool) -> List[str]: """Get keys whose values are lists of Trajectories.""" keys = [] - if is_columar: + if is_columnar: for k, v in trajectories.items(): if isinstance(v, list) and v and self._is_trajectory(v[0]): keys.append(k) else: for k, v in trajectories.items(): - if isinstance(v, dict) and v and self._is_trajectory(v): + if v and self._is_trajectory(v): keys.append(k) return keys