# Patch summary (text before the first "diff --git" header is not part of any hunk).
#
# Changes to init2winit/main_config_flags.py:
#   * Adds one blank line after the `makedirs = tf.io.gfile.makedirs` alias.
#   * Adds `_get_trial_dir(experiment_dir, trial_id)`, which returns
#     `os.path.join(experiment_dir, str(trial_id))`.
#   * `_run` builds `trial_dir` through the new helper instead of an inline
#     `os.path.join(experiment_dir, str(worker_id))`.
#   * `main` stops creating the `<experiment_dir>/r=3` log directory; it creates
#     the trial directory instead and opens the
#     `worker{worker_id}_{process_index}.log` file there.
#
# NOTE(review): in the `main` hunk, `kwargs` is assigned `{}` and passed
# straight to `makedirs(..., **kwargs)`, so it has no effect as written here.
# Presumably a placeholder for options populated elsewhere; confirm, or drop it.
#
# NOTE(review): this patch was recovered from a whitespace-collapsed copy. The
# line breaks agree with the hunk header counts, but the leading indentation of
# every hunk line was reconstructed assuming a 2-space indent. Verify against
# the repository before applying.
diff --git a/init2winit/main_config_flags.py b/init2winit/main_config_flags.py
index f0fe83e4..9a1f270b 100644
--- a/init2winit/main_config_flags.py
+++ b/init2winit/main_config_flags.py
@@ -46,6 +46,7 @@
 # For internal compatibility reasons, we need to pull this function out.
 makedirs = tf.io.gfile.makedirs
 
+
 # Allow caching for any size executables since we had a small executable that
 # took 50 minutes to compile.
 jax.config.update('jax_persistent_cache_min_entry_size_bytes', -1)
@@ -78,6 +79,17 @@
 FLAGS = flags.FLAGS
 
 
+def _get_trial_dir(experiment_dir, trial_id):
+  """Returns the trial directory path.
+
+  Args:
+    experiment_dir: The parent experiment directory.
+    trial_id: For single-point runs, this is the worker_id. For Vizier sweeps,
+      this is the vizier trial id (which we call trial_handle).
+  """
+  return os.path.join(experiment_dir, str(trial_id))
+
+
 def _write_trial_meta_data(meta_data_path, meta_data):
   d = meta_data.copy()
   d['timestamp'] = time.time()
@@ -176,7 +188,7 @@ def _run(
       metrics_name,
       compile_init_on_cpu=compile_init_on_cpu,
   )
-  trial_dir = os.path.join(experiment_dir, str(worker_id))
+  trial_dir = _get_trial_dir(experiment_dir, worker_id)
   meta_data_path = os.path.join(trial_dir, 'meta_data.json')
   meta_data = {'worker_id': worker_id, 'status': 'incomplete'}
   if jax.process_index() == 0:
@@ -262,12 +274,16 @@ def main(unused_argv):
   checkpoint_steps = [int(s.strip()) for s in config.checkpoint_steps]
   eval_steps = [int(s.strip()) for s in config.eval_steps]
   if jax.process_index() == 0:
-    makedirs(experiment_dir, mode=0o775)
-  log_encoding = 'r=3'
-  log_dir = os.path.join(experiment_dir, log_encoding)
-  makedirs(log_dir, mode=0o775)
+    # We don't need to tie the root experiment_dir to any specific worker's
+    # CNS2 cell, as it's just the parent for the trial directories.
+    # The trial directories themselves will get the correct cell placement.
+    kwargs = {}
+    makedirs(experiment_dir, mode=0o775, **kwargs)
+  trial_dir = _get_trial_dir(experiment_dir, worker_id)
+  makedirs(trial_dir, mode=0o775)
   log_path = os.path.join(
-      log_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index()))
+      trial_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index())
+  )
   with gfile.GFile(log_path, 'a') as logfile:
     utils.add_log_file(logfile)
     if jax.process_index() == 0: