Skip to content

Commit a89ede5

Browse files
fix dpo with lazy_dataset (#136)
1 parent 605dc89 commit a89ede5

File tree

1 file changed

+31
-20
lines changed

1 file changed

+31
-20
lines changed

src/twinkle/template/base.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -362,47 +362,58 @@ def _is_trajectory(self, obj: Any) -> bool:
362362
"""Check if an object is a Trajectory (has 'messages' key)."""
363363
return isinstance(obj, Mapping) and 'messages' in obj
364364

365-
def _get_trajectory_keys(self, columnar: Mapping) -> List[str]:
366-
"""Get keys whose values are lists of Trajectories in columnar format."""
365+
def _get_trajectory_keys(self, trajectories: Mapping, is_columnar: bool) -> List[str]:
366+
"""Get keys whose values are lists of Trajectories."""
367367
keys = []
368-
for k, v in columnar.items():
369-
if isinstance(v, list) and v and self._is_trajectory(v[0]):
370-
keys.append(k)
368+
if is_columnar:
369+
for k, v in trajectories.items():
370+
if isinstance(v, list) and v and self._is_trajectory(v[0]):
371+
keys.append(k)
372+
else:
373+
for k, v in trajectories.items():
374+
if v and self._is_trajectory(v):
375+
keys.append(k)
371376
return keys
372377

373378
def batch_encode(
374379
self,
375-
trajectories: Union[Dict[str, Any], List[Trajectory]],
380+
trajectories: Union[Dict[str, List[Any]], List[Trajectory]],
376381
add_generation_prompt: bool = False,
377382
) -> Union[Dict[str, Any], List[InputFeature]]:
378383
"""Encode trajectories into InputFeatures.
379384
380385
Args:
381386
trajectories: Either List[Trajectory] or columnar Dict[str, List].
382-
For DPO, columnar format with 'positive'/'negative' keys containing
383-
List[Trajectory] is supported.
387+
For nested trajectories, columnar format with trajectory list columns
388+
(e.g., 'chosen'/'rejected') is supported.
384389
add_generation_prompt: Whether to add generation prompt.
385390
386391
Returns:
387392
List[InputFeature] or columnar Dict[str, List[InputFeature]].
388393
"""
389394
_transfer = False
390395

396+
# Handle list input
397+
if isinstance(trajectories, list) and len(trajectories) > 0:
398+
# Check if first element has nested trajectories
399+
if isinstance(trajectories[0], Mapping) and len(self._get_trajectory_keys(trajectories[0], False)) > 0:
400+
# Convert row→columnar, process with columnar logic, convert back
401+
columnar = self.map_row_to_col(trajectories)
402+
encoded = self.batch_encode(columnar, add_generation_prompt)
403+
return self.map_col_to_row(encoded)
404+
391405
if isinstance(trajectories, Mapping):
392406
_transfer = True
393-
# Check if it has trajectory list columns (DPO format)
394-
traj_keys = self._get_trajectory_keys(trajectories)
407+
# Check if it has nested trajectory columns
408+
traj_keys = self._get_trajectory_keys(trajectories, True)
395409
if traj_keys:
396-
# DPO format: encode each trajectory list separately, keep other columns
397-
result = {}
398-
for key in trajectories:
399-
if key in traj_keys:
400-
# Encode this trajectory list
401-
result[key] = self.batch_encode(trajectories[key], add_generation_prompt=add_generation_prompt)
402-
else:
403-
# Keep non-trajectory columns as-is
404-
result[key] = trajectories[key]
405-
return result
410+
# Nested format: encode each trajectory list separately, keep other columns
411+
return {
412+
key:
413+
self.batch_encode(trajectories[key], add_generation_prompt)
414+
if key in traj_keys else trajectories[key]
415+
for key in trajectories
416+
}
406417
else:
407418
# Standard columnar format
408419
trajectories = self.map_col_to_row(trajectories)

0 commit comments

Comments
 (0)