From 47880ebbeac32354d68d462bff28f2e6a1942bd2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 30 Mar 2026 16:28:30 -0700 Subject: [PATCH] fix for callback config flags. PiperOrigin-RevId: 891972531 --- init2winit/trainer_lib/base_trainer.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 456dfe7d..31b57966 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -31,6 +31,7 @@ from init2winit.trainer_lib import training_algorithm from init2winit.training_metrics_grabber import make_training_metrics import jax +from ml_collections.config_dict import config_dict import orbax.checkpoint as ocp @@ -192,7 +193,11 @@ def __init__( self._xm_work_unit = None if callback_configs is None: self._callback_configs = [] - elif isinstance(callback_configs, dict): + elif isinstance(callback_configs, (dict, config_dict.ConfigDict)): + # Wrap a single callback config dict (or ConfigDict) into a list. + # Workload definitions that set callback_configs as a bare dict have it + # automatically promoted to a ConfigDict by ml_collections, so we must + # handle both types here. self._callback_configs = [callback_configs] else: self._callback_configs = callback_configs