File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -173,8 +173,8 @@ def save_unreplicated_checkpoint(
173173 # So we first all_gather it to the host and then call jax.device_get
174174 if jax .process_count () > 1 :
175175 unreplicated_optimizer_state = jax .device_get (
176- process_allgather (optimizer_state ))
177- unreplicated_params = jax .device_get (process_allgather (params ))
176+ process_allgather (optimizer_state , tiled = True ))
177+ unreplicated_params = jax .device_get (process_allgather (params , tiled = True ))
178178 else :
179179 unreplicated_optimizer_state = jax .device_get (optimizer_state )
180180 unreplicated_params = jax .device_get (params )
Original file line number Diff line number Diff line change @@ -191,8 +191,8 @@ def evaluate(
191191 # `merge` aggregates the metrics across batches.
192192 metrics = metrics .merge (computed_metrics )
193193
194- metrics = jax .device_get (process_allgather (metrics ))
195- metrics = jax .tree_util .tree_map (lambda x : x [0 ] if x .ndim > 0 else x , metrics )
194+ metrics = jax .device_get (process_allgather (metrics , tiled = True ))
195+ metrics = jax .tree_util .tree_map (lambda x : x [0 ] if x .ndim > 1 else x , metrics )
196196 # For data splits with no data (e.g. ImageNet has no test set), no values
197197 # will appear for that split.
198198 if metrics is not None :
You can’t perform that action at this time.
0 commit comments