Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions src/twinkle/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,47 +362,58 @@ def _is_trajectory(self, obj: Any) -> bool:
"""Check if an object is a Trajectory (has 'messages' key)."""
return isinstance(obj, Mapping) and 'messages' in obj

def _get_trajectory_keys(self, columnar: Mapping) -> List[str]:
"""Get keys whose values are lists of Trajectories in columnar format."""
def _get_trajectory_keys(self, trajectories: Mapping, is_columnar: bool) -> List[str]:
"""Get keys whose values are lists of Trajectories."""
keys = []
for k, v in columnar.items():
if isinstance(v, list) and v and self._is_trajectory(v[0]):
keys.append(k)
if is_columnar:
for k, v in trajectories.items():
if isinstance(v, list) and v and self._is_trajectory(v[0]):
keys.append(k)
else:
for k, v in trajectories.items():
if v and self._is_trajectory(v):
keys.append(k)
return keys

def batch_encode(
self,
trajectories: Union[Dict[str, Any], List[Trajectory]],
trajectories: Union[Dict[str, List[Any]], List[Trajectory]],
add_generation_prompt: bool = False,
) -> Union[Dict[str, Any], List[InputFeature]]:
"""Encode trajectories into InputFeatures.

Args:
trajectories: Either List[Trajectory] or columnar Dict[str, List].
For DPO, columnar format with 'positive'/'negative' keys containing
List[Trajectory] is supported.
For nested trajectories, columnar format with trajectory list columns
(e.g., 'chosen'/'rejected') is supported.
add_generation_prompt: Whether to add generation prompt.

Returns:
List[InputFeature] or columnar Dict[str, List[InputFeature]].
"""
_transfer = False

# Handle list input
if isinstance(trajectories, list) and len(trajectories) > 0:
# Check if first element has nested trajectories
if isinstance(trajectories[0], Mapping) and len(self._get_trajectory_keys(trajectories[0], False)) > 0:
# Convert row→columnar, process with columnar logic, convert back
columnar = self.map_row_to_col(trajectories)
encoded = self.batch_encode(columnar, add_generation_prompt)
return self.map_col_to_row(encoded)

if isinstance(trajectories, Mapping):
_transfer = True
# Check if it has trajectory list columns (DPO format)
traj_keys = self._get_trajectory_keys(trajectories)
# Check if it has nested trajectory columns
traj_keys = self._get_trajectory_keys(trajectories, True)
if traj_keys:
# DPO format: encode each trajectory list separately, keep other columns
result = {}
for key in trajectories:
if key in traj_keys:
# Encode this trajectory list
result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt)
else:
# Keep non-trajectory columns as-is
result[key] = trajectories[key]
return result
# Nested format: encode each trajectory list separately, keep other columns
return {
key:
self.batch_encode(trajectories[key], add_generation_prompt)
if key in traj_keys else trajectories[key]
for key in trajectories
}
else:
# Standard columnar format
trajectories = self.map_col_to_row(trajectories)
Expand Down
Loading