Skip to content

Commit 8ae58e6

Browse files
committed
update handler
1 parent 8220365 commit 8ae58e6

File tree

1 file changed

+8
-15
lines changed

1 file changed

+8
-15
lines changed

src/twinkle/server/common/datum.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,20 +133,13 @@ def input_feature_to_datum(input_feature: InputFeature) -> types.Datum:
133133

134134
# 3. Optionally pack multimodal tensors into loss_fn_inputs so that
135135
# the server-side ``datum_to_input_feature`` can restore them.
136-
if 'pixel_values' in input_feature and input_feature['pixel_values'] is not None:
137-
pixel_values = input_feature['pixel_values']
138-
if hasattr(pixel_values, 'detach'):
139-
pixel_values = pixel_values.detach().cpu().numpy()
140-
elif not isinstance(pixel_values, np.ndarray):
141-
pixel_values = np.asarray(pixel_values)
142-
loss_fn_inputs['pixel_values'] = types.TensorData.from_numpy(pixel_values)
143-
144-
if 'image_grid_thw' in input_feature and input_feature['image_grid_thw'] is not None:
145-
image_grid_thw = input_feature['image_grid_thw']
146-
if hasattr(image_grid_thw, 'detach'):
147-
image_grid_thw = image_grid_thw.detach().cpu().numpy()
148-
elif not isinstance(image_grid_thw, np.ndarray):
149-
image_grid_thw = np.asarray(image_grid_thw)
150-
loss_fn_inputs['image_grid_thw'] = types.TensorData.from_numpy(image_grid_thw)
136+
for key in ('pixel_values', 'image_grid_thw'):
137+
if key in input_feature and input_feature[key] is not None:
138+
value = input_feature[key]
139+
if hasattr(value, 'detach'):
140+
value = value.detach().cpu().numpy()
141+
elif not isinstance(value, np.ndarray):
142+
value = np.asarray(value)
143+
loss_fn_inputs[key] = types.TensorData.from_numpy(value)
151144

152145
return types.Datum(loss_fn_inputs=loss_fn_inputs, model_input=model_input)

0 commit comments

Comments
 (0)