Skip to content

Commit b973e7d

Browse files
priyakasimbegcopybara-github
authored andcommitted
fix for callback config flags.
PiperOrigin-RevId: 891892381
1 parent 202d303 commit b973e7d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

init2winit/trainer_lib/base_trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from init2winit.trainer_lib import training_algorithm
3232
from init2winit.training_metrics_grabber import make_training_metrics
3333
import jax
34+
from ml_collections.config_dict import config_dict
3435
import orbax.checkpoint as ocp
3536

3637

@@ -192,7 +193,11 @@ def __init__(
192193
self._xm_work_unit = None
193194
if callback_configs is None:
194195
self._callback_configs = []
195-
elif isinstance(callback_configs, dict):
196+
elif isinstance(callback_configs, (dict, config_dict.ConfigDict)):
197+
# Wrap a single callback config dict (or ConfigDict) into a list.
198+
# Workload definitions that set callback_configs as a bare dict have it
199+
# automatically promoted to a ConfigDict by ml_collections, so we must
200+
# handle both types here.
196201
self._callback_configs = [callback_configs]
197202
else:
198203
self._callback_configs = callback_configs

0 commit comments

Comments
 (0)