Status: Open
Labels: bug (Something isn't working), enhancement (New feature or request)
Description
When running the training loop, the first iteration usually takes the longest because JAX has to trace and compile the `make_step` function on its first call. When a tqdm progress bar is used to show training progress, this slow first step significantly skews the estimated remaining training time, especially at the beginning of training.
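A minimal sketch of where the first-step cost comes from, assuming a toy, hypothetical `make_step` (the project's real step function is not shown in this issue). Under `jax.jit`, the Python body runs only while JAX traces the function for a new combination of input shapes and dtypes; the compiled result is cached and reused for later calls with matching inputs. The `trace_count` counter is an illustrative addition, not part of the project.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter: increments only when JAX traces make_step


@jax.jit
def make_step(params, x):
    # Python side effects like this only run during tracing, so the counter
    # shows how often JAX had to trace (and then compile) the function.
    global trace_count
    trace_count += 1
    # Toy "training step": one gradient-descent update on sum((p * x) ** 2).
    grads = jax.grad(lambda p: jnp.sum((p * x) ** 2))(params)
    return params - 0.01 * grads


params = jnp.ones(3)
x = jnp.arange(3.0)
for _ in range(5):
    params = make_step(params, x)
    params.block_until_ready()  # wait for the async computation to finish

print(f"make_step traced {trace_count} time(s) over 5 calls")
```

Because every call sees the same shapes and dtypes, only the first call pays the tracing and compilation cost; that one slow iteration is what tqdm folds into its early rate estimate.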
Proposed solutions:
- JAX warmup step: simply do one (wasted) step before the training loop to compile `make_step` before tqdm even notices.
- Reset the tqdm internals after the first step (seems sketchy).
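A minimal sketch of the warmup-step option, again assuming a toy, hypothetical `make_step`, `params`, and `x` in place of the project's real training code. The warmup call must use inputs with the same shapes and dtypes as the loop, otherwise JAX traces and compiles again inside the loop and the ETA is skewed anyway.

```python
import jax
import jax.numpy as jnp
from tqdm import tqdm


@jax.jit
def make_step(params, x):
    # Hypothetical toy training step standing in for the project's make_step.
    grads = jax.grad(lambda p: jnp.sum((p * x) ** 2))(params)
    return params - 0.01 * grads


num_steps = 100
params = jnp.ones(3)
x = jnp.arange(3.0)

# Warmup: one throwaway call triggers tracing and XLA compilation.
# The result is discarded; block_until_ready() waits until the call
# (including compilation) has finished before the progress bar exists.
make_step(params, x).block_until_ready()

# The progress bar is created only after compilation, so its rate and
# ETA estimates reflect the steady-state step time.
for _ in tqdm(range(num_steps)):
    params = make_step(params, x)
params.block_until_ready()
```

The wasted step costs one extra execution of `make_step`. If that matters, JAX's ahead-of-time API (`jax.jit(f).lower(*args).compile()`) compiles without running a step, but the loop should then call the returned compiled object rather than the original jitted function.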