File tree Expand file tree Collapse file tree 1 file changed +6
-1
lines changed
Expand file tree Collapse file tree 1 file changed +6
-1
lines changed Original file line number Diff line number Diff line change 3131from init2winit .trainer_lib import training_algorithm
3232from init2winit .training_metrics_grabber import make_training_metrics
3333import jax
34+ from ml_collections .config_dict import config_dict
3435import orbax .checkpoint as ocp
3536
3637
@@ -192,7 +193,11 @@ def __init__(
192193 self ._xm_work_unit = None
193194 if callback_configs is None :
194195 self ._callback_configs = []
195- elif isinstance (callback_configs , dict ):
196+ elif isinstance (callback_configs , (dict , config_dict .ConfigDict )):
197+ # Wrap a single callback config dict (or ConfigDict) into a list.
198+ # Workload definitions that set callback_configs as a bare dict have it
199+ # automatically promoted to a ConfigDict by ml_collections, so we must
200+ # handle both types here.
196201 self ._callback_configs = [callback_configs ]
197202 else :
198203 self ._callback_configs = callback_configs
You can’t perform that action at this time.
0 commit comments