Skip to content

Commit e543cdd

Browse files
fix video mm (#105)
1 parent 76a20a3 commit e543cdd

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

src/twinkle/processor/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,13 @@ def to_transformers_dict(inputs: List[InputFeature], **kwargs) -> List[InputFeat
311311
for _input in inputs:
312312
output = {}
313313
_keys = [
314-
'input_ids', 'input_embeddings', 'attention_mask', 'position_ids', 'labels', 'completion_mask',
315-
'pixel_values', 'image_grid_thw'
316-
]
314+
'input_ids',
315+
'input_embeddings',
316+
'attention_mask',
317+
'position_ids',
318+
'labels',
319+
'completion_mask',
320+
] + list(InputProcessor.VLM_CONCAT_FIELDS)
317321
for key in list(_input.keys()):
318322
if key in _keys:
319323
output[key] = np.array(_input[key]) if not isinstance(_input[key], torch.Tensor) else _input[key]

src/twinkle/template/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]:
226226
message['images'] = self.preprocess_images(msg_images)
227227
assert len(message['images']) == content.count(self.image_placeholder)
228228
if msg_videos:
229-
message['videos'] = self.preprocess_images(msg_videos)
229+
message['videos'] = self.preprocess_videos(msg_videos)
230230
assert len(message['videos']) == content.count(self.video_placeholder)
231231
if msg_audios:
232232
message['audios'] = self.preprocess_audios(msg_audios)

0 commit comments

Comments
 (0)