Skip to content

Commit 8c20140

Browse files
georgedahl authored and copybara-github committed
Changing convention on checkpoint paths.
PiperOrigin-RevId: 844876403
1 parent 9a62e1b commit 8c20140

File tree

4 files changed

+14
-21
lines changed

4 files changed

+14
-21
lines changed

init2winit/trainer_lib/base_trainer.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434
import orbax.checkpoint as ocp
3535

3636

37+
CHECKPOINT_TTL = 'ttl=180d'
38+
39+
3740
class BaseTrainer(metaclass=abc.ABCMeta):
3841
"""Abstract parent class for all trainers."""
3942

@@ -142,6 +145,7 @@ def __init__(
142145
del loss_name
143146
del metrics_name
144147
self._train_dir = train_dir
148+
self._checkpoint_dir = os.path.join(train_dir, CHECKPOINT_TTL)
145149
self._model = model
146150
self._dataset_builder = dataset_builder
147151
self._data_selector = data_selector
@@ -200,7 +204,9 @@ def __init__(
200204

201205
# Only used if checkpoints_steps is non-empty. Standard checkpoints are
202206
# saved in train_dir.
203-
self._checkpoint_dir = os.path.join(self._train_dir, 'checkpoints')
207+
self._extra_checkpoint_dir = os.path.join(
208+
self._checkpoint_dir, 'checkpoints'
209+
)
204210

205211
# During eval, we can donate the 'batch' buffer. We don't donate the
206212
# 'params' and 'batch_stats' buffers as we don't re-assign those values in
@@ -256,7 +262,7 @@ def maybe_restore_from_checkpoint(self,
256262
unreplicated_params,
257263
unreplicated_batch_stats,
258264
unreplicated_metrics_state,
259-
train_dir=self._train_dir,
265+
train_dir=self._checkpoint_dir,
260266
external_checkpoint_path=self._external_checkpoint_path,
261267
orbax_checkpointer=self._orbax_checkpointer,
262268
)
@@ -405,7 +411,7 @@ def _eval(self, start_step, start_time, save=True):
405411
406412
Has the side-effects of:
407413
- synchronizing self._batch_stats across hosts
408-
- checkpointing via self._save(self._train_dir)
414+
- checkpointing via self._save(self._checkpoint_dir)
409415
- resetting self._sum_train_cost to jnp.zeros
410416
- resetting self._time_at_prev_eval_end to the current time
411417
- resetting self._prev_eval_step to self._global_step
@@ -440,7 +446,7 @@ def _eval(self, start_step, start_time, save=True):
440446
)
441447
self._run_eval_callbacks(report)
442448
if save:
443-
self._save(self._train_dir)
449+
self._save(self._checkpoint_dir)
444450
steps_since_last_eval = self._global_step - self._prev_eval_step
445451
steps_per_sec_no_eval = steps_since_last_eval / time_since_last_eval
446452
run_time = time.time() - self._time_at_prev_eval_end
@@ -635,7 +641,7 @@ def train(self):
635641
self._prev_eval_step = self._global_step
636642

637643
if self._global_step in self._checkpoint_steps:
638-
self._save(self._checkpoint_dir, max_to_keep=None)
644+
self._save(self._extra_checkpoint_dir, max_to_keep=None)
639645

640646
for _ in range(start_step, self._num_train_steps):
641647
with jax.profiler.StepTraceAnnotation(
@@ -671,7 +677,7 @@ def train(self):
671677
self._sum_train_cost,
672678
)
673679
if self._global_step in self._checkpoint_steps:
674-
self._save(self._checkpoint_dir, max_to_keep=None)
680+
self._save(self._extra_checkpoint_dir, max_to_keep=None)
675681

676682
# TODO(gdahl, gilmer): consider moving this test up.
677683
# NB: Since this test is after we increment self._global_step, having 0

init2winit/trainer_lib/test_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -738,7 +738,7 @@ def as_dataset(self, *args, **kwargs):
738738
epoch_reports = list(self.trainer.train())
739739

740740
# check that the additional checkpoints are saved.
741-
checkpoint_dir = os.path.join(self.test_dir, 'checkpoints')
741+
checkpoint_dir = os.path.join(self.test_dir, 'ttl=180d', 'checkpoints')
742742
saved_steps = []
743743
for f in tf.io.gfile.listdir(checkpoint_dir):
744744
if f[:5] == 'ckpt_':

init2winit/trainer_lib/trainer_utils.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import time
2020

2121
from absl import logging
22-
from flax import jax_utils
2322
from init2winit import utils
2423
from init2winit.dataset_lib import data_utils
2524
import jax
@@ -76,18 +75,6 @@ def log_epoch_report(report, metrics_logger):
7675
report['epoch'])
7776

7877

79-
def maybe_log_training_metrics(metrics_state,
80-
metrics_summary_fn,
81-
metrics_logger):
82-
"""If appropriate, send a summary tree of training metrics to the logger."""
83-
if metrics_state:
84-
unreplicated_metrics_state = jax_utils.unreplicate(metrics_state)
85-
summary_tree = metrics_summary_fn(unreplicated_metrics_state)
86-
metrics_logger.append_pytree(summary_tree)
87-
metrics_logger.write_pytree(unreplicated_metrics_state,
88-
prefix='metrics_state')
89-
90-
9178
def should_eval(global_step, eval_frequency, eval_steps):
9279
on_step = eval_steps and global_step in eval_steps
9380
on_freq = (global_step % eval_frequency == 0)

init2winit/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def wrapper(*args, **kwargs):
105105
def set_up_loggers(train_dir, xm_work_unit=None):
106106
"""Creates a logger for eval metrics as well as initialization metrics."""
107107
csv_path = os.path.join(train_dir, 'measurements.csv')
108-
pytree_path = os.path.join(train_dir, 'training_metrics')
108+
pytree_path = os.path.join(train_dir, 'ttl=180d', 'training_metrics')
109109
metrics_logger = MetricLogger(
110110
csv_path=csv_path,
111111
pytree_path=pytree_path,

0 commit comments

Comments (0)