Hi team - a question about the following line of code:
Line 144 in b4b7b97
    mask_prob = self.args.mask_prob + (.8 - self.args.mask_prob) * (epoch - 1) / 20
Can I ask what purpose this serves? When training resumes from a checkpoint whose epoch count exceeds 28, the result is greater than 1 and triggers assertion failures. I also note that when training starts from epoch 0 with no interruption, mask_prob stays fixed at its initialized value (0.2), but after any interruption training resumes with a higher mask_prob.
Should this just be mask_prob = self.args.mask_prob?
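To make the arithmetic concrete, here is a minimal standalone sketch of the quoted expression. The function name `scheduled_mask_prob` and the constant `BASE` are hypothetical, introduced only for illustration. It assumes the initialized value of 0.2 mentioned above and that `epoch` is exactly the integer fed into the expression. It does not reproduce the surrounding training loop or the assertion itself.

```python
def scheduled_mask_prob(base_mask_prob: float, epoch: int) -> float:
    """Standalone copy of the expression quoted above (hypothetical helper)."""
    return base_mask_prob + (0.8 - base_mask_prob) * (epoch - 1) / 20


BASE = 0.2  # assumed initialized value of args.mask_prob, as stated in the question

# Print the value the expression produces at a few sample epochs.
for epoch in (1, 10, 21, 27, 28, 30):
    print(f"epoch {epoch:2d}: mask_prob = {scheduled_mask_prob(BASE, epoch):.2f}")

# Find the first epoch at which the value is no longer a valid probability.
first_invalid_epoch = next(
    e for e in range(1, 1000) if scheduled_mask_prob(BASE, e) > 1.0
)
print("first epoch with mask_prob > 1:", first_invalid_epoch)
```

Under these assumptions the expression is a linear ramp: it equals the base value at epoch 1, reaches 0.8 at epoch 21, and keeps growing past 1.0 after that, which is consistent with the assertion failures seen when resuming at a late epoch.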