|
46 | 46 | # For internal compatibility reasons, we need to pull this function out. |
47 | 47 | makedirs = tf.io.gfile.makedirs |
48 | 48 |
|
| 49 | + |
49 | 50 | # Allow caching for any size executables since we had a small executable that |
50 | 51 | # took 50 minutes to compile. |
51 | 52 | jax.config.update('jax_persistent_cache_min_entry_size_bytes', -1) |
|
78 | 79 | FLAGS = flags.FLAGS |
79 | 80 |
|
80 | 81 |
|
| 82 | +def _get_trial_dir(experiment_dir, trial_id): |
| 83 | + """Returns the trial directory path, experiment_dir/str(trial_id).
| 84 | +
|
| 85 | + Args: |
| 86 | + experiment_dir: The parent experiment directory. |
| 87 | + trial_id: For single-point runs, this is the worker_id. For Vizier sweeps,
| 88 | + this is the Vizier trial ID (which we call trial_handle).
| 89 | + """ |
| 90 | + return os.path.join(experiment_dir, str(trial_id)) |
| 91 | + |
| 92 | + |
81 | 93 | def _write_trial_meta_data(meta_data_path, meta_data): |
82 | 94 | d = meta_data.copy() |
83 | 95 | d['timestamp'] = time.time() |
@@ -176,7 +188,7 @@ def _run( |
176 | 188 | metrics_name, |
177 | 189 | compile_init_on_cpu=compile_init_on_cpu, |
178 | 190 | ) |
179 | | - trial_dir = os.path.join(experiment_dir, str(worker_id)) |
| 191 | + trial_dir = _get_trial_dir(experiment_dir, worker_id) |
180 | 192 | meta_data_path = os.path.join(trial_dir, 'meta_data.json') |
181 | 193 | meta_data = {'worker_id': worker_id, 'status': 'incomplete'} |
182 | 194 | if jax.process_index() == 0: |
@@ -262,12 +274,16 @@ def main(unused_argv): |
262 | 274 | checkpoint_steps = [int(s.strip()) for s in config.checkpoint_steps] |
263 | 275 | eval_steps = [int(s.strip()) for s in config.eval_steps] |
264 | 276 | if jax.process_index() == 0: |
265 | | - makedirs(experiment_dir, mode=0o775) |
266 | | - log_encoding = 'r=3' |
267 | | - log_dir = os.path.join(experiment_dir, log_encoding) |
268 | | - makedirs(log_dir, mode=0o775) |
| 277 | + # We don't need to tie the root experiment_dir to any specific worker's |
| 278 | + # CNS2 cell, as it's just the parent for the trial directories. |
| 279 | + # The trial directories themselves will get the correct cell placement. |
| 280 | + kwargs = {} |
| 281 | + makedirs(experiment_dir, mode=0o775, **kwargs) |
| 282 | + trial_dir = _get_trial_dir(experiment_dir, worker_id) |
| 283 | + makedirs(trial_dir, mode=0o775) |
269 | 284 | log_path = os.path.join( |
270 | | - log_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index())) |
| 285 | + trial_dir, 'worker{}_{}.log'.format(worker_id, jax.process_index()) |
| 286 | + ) |
271 | 287 | with gfile.GFile(log_path, 'a') as logfile: |
272 | 288 | utils.add_log_file(logfile) |
273 | 289 | if jax.process_index() == 0: |
|
0 commit comments