diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 456dfe7d..31b57966 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -31,6 +31,7 @@ from init2winit.trainer_lib import training_algorithm from init2winit.training_metrics_grabber import make_training_metrics import jax +from ml_collections.config_dict import config_dict import orbax.checkpoint as ocp @@ -192,7 +193,11 @@ def __init__( self._xm_work_unit = None if callback_configs is None: self._callback_configs = [] - elif isinstance(callback_configs, dict): + elif isinstance(callback_configs, (dict, config_dict.ConfigDict)): + # Wrap a single callback config dict (or ConfigDict) into a list. + # Workload definitions that set callback_configs as a bare dict have it + # automatically promoted to a ConfigDict by ml_collections, so we must + # handle both types here. self._callback_configs = [callback_configs] else: self._callback_configs = callback_configs