Dear authors,
Training using the following script:
```shell
python src/tevatron/retriever/driver/train.py \
  --do_train \
  --fp16 \
  --per_device_train_batch_size 32 \
  --learning_rate 1e-5 \
  --num_train_epochs 1 \
  --attn_implementation sdpa \
  --dataset_name Tevatron/scifact \
  --model_name_or_path bert-base-uncased \
  --output_dir $OUTPUT_DIR/$EXP_NAME \
  --overwrite_output_dir
```

Training fails with the following error:
```
Traceback (most recent call last):
  File "/home/thuy0050/code/tevatron/src/tevatron/retriever/driver/train.py", line 113, in <module>
    main()
  File "/home/thuy0050/code/tevatron/src/tevatron/retriever/driver/train.py", line 106, in main
    trainer.train(resume_from_checkpoint=(last_checkpoint is not None))
  File "/scratch/ft49/thuy0050/miniconda/conda/envs/tevatron/lib/python3.10/site-packages/transformers/trainer.py", line 2240, in train
    return inner_training_loop(
  File "/scratch/ft49/thuy0050/miniconda/conda/envs/tevatron/lib/python3.10/site-packages/transformers/trainer.py", line 2588, in _inner_training_loop
    _grad_norm = self.accelerator.clip_grad_norm_(
  File "/scratch/ft49/thuy0050/miniconda/conda/envs/tevatron/lib/python3.10/site-packages/accelerate/accelerator.py", line 2628, in clip_grad_norm_
    self.unscale_gradients()
  File "/scratch/ft49/thuy0050/miniconda/conda/envs/tevatron/lib/python3.10/site-packages/accelerate/accelerator.py", line 2567, in unscale_gradients
    self.scaler.unscale_(opt)
  File "/scratch/ft49/thuy0050/miniconda/conda/envs/tevatron/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 342, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/scratch/ft49/thuy0050/miniconda/conda/envs/tevatron/lib/python3.10/site-packages/torch/amp/grad_scaler.py", line 264, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.
```

- This issue is caused by the `--fp16` flag, which internally sets `torch_dtype=torch.float16` when calling the `DenseModel.build()` function. This function uses `AutoModel.from_pretrained` to load the model.
- Removing `torch_dtype` from `DenseModel.build()` seems to solve the problem.
- Additionally, when loading `bert-base-uncased` with `AutoModel.from_pretrained`, `--attn_implementation sdpa` must be specified to ensure compatibility with the attention backend.
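For illustration, here is a pure-Python sketch of the check behind this error, a simplified analog of `GradScaler._unscale_grads_` in `torch/amp/grad_scaler.py` (the dict-based gradients and the `unscale_grads` helper are illustrative stand-ins, not the real torch API): unscaling is done against fp32 master weights, so parameters loaded in fp16, which produce fp16 gradients, are rejected outright.

```python
# Simplified analog of the check in torch/amp/grad_scaler.py:
# GradScaler assumes fp32 master weights; if a parameter's gradient is
# fp16 (as happens when the model is loaded with torch_dtype=torch.float16),
# unscale_() refuses to proceed.

FP16, FP32 = "float16", "float32"  # stand-ins for torch dtypes


def unscale_grads(grads, inv_scale):
    """Divide each gradient by the loss scale, rejecting fp16 grads."""
    for g in grads:
        if g["dtype"] == FP16:
            raise ValueError("Attempting to unscale FP16 gradients.")
        g["value"] *= inv_scale


# Params loaded with torch_dtype=torch.float16 -> fp16 gradients -> error:
try:
    unscale_grads([{"dtype": FP16, "value": 2.0}], 1 / 1024)
except ValueError as e:
    print(e)  # Attempting to unscale FP16 gradients.

# With fp32 master weights (torch_dtype removed), unscaling succeeds:
fp32_grads = [{"dtype": FP32, "value": 2.0}]
unscale_grads(fp32_grads, 0.5)
print(fp32_grads[0]["value"])  # 1.0
```

This is why dropping `torch_dtype` helps: the weights stay in fp32 and the Trainer's AMP autocast handles the fp16 compute, so the gradients `GradScaler` sees are fp32.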
The environment is:
```
datasets==3.6.0
faiss_cpu==1.11.0
numpy==2.2.6
peft==0.11.1
pyserini==0.44.0
torch==2.7.0+cu128
transformers==4.52.4
vllm==0.9.0.1
```

Best regards,
Louis