diff --git a/the_well/benchmark/trainer/training.py b/the_well/benchmark/trainer/training.py index 772dfa41..c7200083 100644 --- a/the_well/benchmark/trainer/training.py +++ b/the_well/benchmark/trainer/training.py @@ -460,6 +460,7 @@ def train(self): self.save_model( epoch, val_loss, os.path.join(self.checkpoint_folder, "best.pt") ) + self.best_val_loss = val_loss # Check if time for expensive validation - periodic or final if epoch % self.rollout_val_frequency == 0 or (epoch == self.max_epoch): logger.info(