Skip to content

Commit 6cfcdae

Browse files
Ahmed Khaled and copybara-github
authored and committed
internal
PiperOrigin-RevId: 885758453
1 parent 1fc1409 commit 6cfcdae

File tree

1 file changed

+22
-6
lines changed

1 file changed

+22
-6
lines changed

init2winit/main_config_flags.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
# For internal compatibility reasons, we need to pull this function out.
4747
makedirs = tf.io.gfile.makedirs
4848

49+
4950
# Allow caching for any size executables since we had a small executable that
5051
# took 50 minutes to compile.
5152
jax.config.update('jax_persistent_cache_min_entry_size_bytes', -1)
@@ -78,6 +79,17 @@
7879
FLAGS = flags.FLAGS
7980

8081

82+
def _get_trial_dir(experiment_dir, trial_id):
83+
"""Returns the trial directory path.
84+
85+
Args:
86+
experiment_dir: The parent experiment directory.
87+
trial_id: For single-point runs, this is the worker_id. For Vizier sweeps,
88+
this is the vizier trial id (which we call trial_handle).
89+
"""
90+
return os.path.join(experiment_dir, str(trial_id))
91+
92+
8193
def _write_trial_meta_data(meta_data_path, meta_data):
8294
d = meta_data.copy()
8395
d['timestamp'] = time.time()
@@ -176,7 +188,7 @@ def _run(
176188
metrics_name,
177189
compile_init_on_cpu=compile_init_on_cpu,
178190
)
179-
trial_dir = os.path.join(experiment_dir, str(worker_id))
191+
trial_dir = _get_trial_dir(experiment_dir, worker_id)
180192
meta_data_path = os.path.join(trial_dir, 'meta_data.json')
181193
meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
182194
if jax.process_index() == 0:
@@ -262,12 +274,16 @@ def main(unused_argv):
262274
checkpoint_steps = [int(s.strip()) for s in config.checkpoint_steps]
263275
eval_steps = [int(s.strip()) for s in config.eval_steps]
264276
if jax.process_index() == 0:
265-
makedirs(experiment_dir, mode=0o775)
266-
log_encoding = 'r=3'
267-
log_dir = os.path.join(experiment_dir, log_encoding)
268-
makedirs(log_dir, mode=0o775)
277+
# We don't need to tie the root experiment_dir to any specific worker's
278+
# CNS2 cell, as it's just the parent for the trial directories.
279+
# The trial directories themselves will get the correct cell placement.
280+
kwargs = {}
281+
makedirs(experiment_dir, mode=0o775, **kwargs)
282+
trial_dir = _get_trial_dir(experiment_dir, worker_id)
283+
makedirs(trial_dir, mode=0o775)
269284
log_path = os.path.join(
270-
log_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index()))
285+
trial_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index())
286+
)
271287
with gfile.GFile(log_path, 'a') as logfile:
272288
utils.add_log_file(logfile)
273289
if jax.process_index() == 0:

0 commit comments

Comments
 (0)