|
34 | 34 | import orbax.checkpoint as ocp |
35 | 35 |
|
36 | 36 |
|
| 37 | +CHECKPOINT_TTL = 'ttl=180d' |
| 38 | + |
| 39 | + |
37 | 40 | class BaseTrainer(metaclass=abc.ABCMeta): |
38 | 41 | """Abstract parent class for all trainers.""" |
39 | 42 |
|
@@ -142,6 +145,7 @@ def __init__( |
142 | 145 | del loss_name |
143 | 146 | del metrics_name |
144 | 147 | self._train_dir = train_dir |
| 148 | + self._checkpoint_dir = os.path.join(train_dir, CHECKPOINT_TTL) |
145 | 149 | self._model = model |
146 | 150 | self._dataset_builder = dataset_builder |
147 | 151 | self._data_selector = data_selector |
@@ -200,7 +204,9 @@ def __init__( |
200 | 204 |
|
201 | 205 |  # Only used if checkpoints_steps is non-empty. Standard checkpoints are |
202 | 206 |  # saved in self._checkpoint_dir (train_dir joined with the TTL suffix); |
203 | | - self._checkpoint_dir = os.path.join(self._train_dir, 'checkpoints') |
| 207 | + self._extra_checkpoint_dir = os.path.join( |
| 208 | + self._checkpoint_dir, 'checkpoints' |
| 209 | + ) |
204 | 210 |
|
205 | 211 | # During eval, we can donate the 'batch' buffer. We don't donate the |
206 | 212 | # 'params' and 'batch_stats' buffers as we don't re-assign those values in |
@@ -256,7 +262,7 @@ def maybe_restore_from_checkpoint(self, |
256 | 262 | unreplicated_params, |
257 | 263 | unreplicated_batch_stats, |
258 | 264 | unreplicated_metrics_state, |
259 | | - train_dir=self._train_dir, |
| 265 | + train_dir=self._checkpoint_dir, |
260 | 266 | external_checkpoint_path=self._external_checkpoint_path, |
261 | 267 | orbax_checkpointer=self._orbax_checkpointer, |
262 | 268 | ) |
@@ -405,7 +411,7 @@ def _eval(self, start_step, start_time, save=True): |
405 | 411 |
|
406 | 412 | Has the side-effects of: |
407 | 413 | - synchronizing self._batch_stats across hosts |
408 | | - - checkpointing via self._save(self._train_dir) |
| 414 | + - checkpointing via self._save(self._checkpoint_dir) |
409 | 415 | - resetting self._sum_train_cost to jnp.zeros |
410 | 416 | - resetting self._time_at_prev_eval_end to the current time |
411 | 417 | - resetting self._prev_eval_step to self._global_step |
@@ -440,7 +446,7 @@ def _eval(self, start_step, start_time, save=True): |
440 | 446 | ) |
441 | 447 | self._run_eval_callbacks(report) |
442 | 448 | if save: |
443 | | - self._save(self._train_dir) |
| 449 | + self._save(self._checkpoint_dir) |
444 | 450 | steps_since_last_eval = self._global_step - self._prev_eval_step |
445 | 451 | steps_per_sec_no_eval = steps_since_last_eval / time_since_last_eval |
446 | 452 | run_time = time.time() - self._time_at_prev_eval_end |
@@ -635,7 +641,7 @@ def train(self): |
635 | 641 | self._prev_eval_step = self._global_step |
636 | 642 |
|
637 | 643 | if self._global_step in self._checkpoint_steps: |
638 | | - self._save(self._checkpoint_dir, max_to_keep=None) |
| 644 | + self._save(self._extra_checkpoint_dir, max_to_keep=None) |
639 | 645 |
|
640 | 646 | for _ in range(start_step, self._num_train_steps): |
641 | 647 | with jax.profiler.StepTraceAnnotation( |
@@ -671,7 +677,7 @@ def train(self): |
671 | 677 | self._sum_train_cost, |
672 | 678 | ) |
673 | 679 | if self._global_step in self._checkpoint_steps: |
674 | | - self._save(self._checkpoint_dir, max_to_keep=None) |
| 680 | + self._save(self._extra_checkpoint_dir, max_to_keep=None) |
675 | 681 |
|
676 | 682 | # TODO(gdahl, gilmer): consider moving this test up. |
677 | 683 | # NB: Since this test is after we increment self._global_step, having 0 |
|
0 commit comments