Skip to content

Commit e7a59fe

Browse files
priyakasimbeg and copybara-github
authored and committed
internal change
PiperOrigin-RevId: 814399303
1 parent eeec743 commit e7a59fe

2 files changed

Lines changed: 4 additions & 4 deletions

File tree

init2winit/checkpoint.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -173,8 +173,8 @@ def save_unreplicated_checkpoint(
173173
# So we first all_gather it to the host and then call jax.device_get
174174
if jax.process_count() > 1:
175175
unreplicated_optimizer_state = jax.device_get(
176-
process_allgather(optimizer_state))
177-
unreplicated_params = jax.device_get(process_allgather(params))
176+
process_allgather(optimizer_state, tiled=True))
177+
unreplicated_params = jax.device_get(process_allgather(params, tiled=True))
178178
else:
179179
unreplicated_optimizer_state = jax.device_get(optimizer_state)
180180
unreplicated_params = jax.device_get(params)

init2winit/trainer_lib/trainer_utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -191,8 +191,8 @@ def evaluate(
191191
# `merge` aggregates the metrics across batches.
192192
metrics = metrics.merge(computed_metrics)
193193

194-
metrics = jax.device_get(process_allgather(metrics))
195-
metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 0 else x, metrics)
194+
metrics = jax.device_get(process_allgather(metrics, tiled=True))
195+
metrics = jax.tree_util.tree_map(lambda x: x[0] if x.ndim > 1 else x, metrics)
196196
# For data splits with no data (e.g. Imagenet no test set) no values
197197
# will appear for that split.
198198
if metrics is not None:

0 commit comments

Comments
 (0)