Skip to content

Commit d5d2832

Browse files
committed
fix(test): pass inputs as List[InputFeature] to forward_backward
1 parent 983cdbc commit d5d2832

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

tests/strategy/test_fsdp2_memory_efficient_init.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -326,12 +326,13 @@ def _worker_e2e_memory_efficient(rank, world_size, port, model_path):
326326
)
327327
model.set_optimizer('AdamW', lr=1e-4)
328328

329-
# Create a dummy batch
330-
batch = {
331-
'input_ids': torch.randint(0, 1000, (1, 16)).to(_DEVICE_TYPE),
332-
'labels': torch.randint(0, 1000, (1, 16)).to(_DEVICE_TYPE),
333-
'attention_mask': torch.ones(1, 16, dtype=torch.long).to(_DEVICE_TYPE),
334-
}
329+
# Create a dummy batch — inputs must be a list of dicts (List[InputFeature]).
330+
# The processor's to_tensor() handles device placement internally.
331+
batch = [{
332+
'input_ids': torch.randint(0, 1000, (16,)),
333+
'labels': torch.randint(0, 1000, (16,)),
334+
'attention_mask': torch.ones(16, dtype=torch.long),
335+
}]
335336

336337
# This triggers _lazy_wrap_model → wrap_model(memory_efficient=True)
337338
model.forward_backward(inputs=batch)

0 commit comments

Comments
 (0)