def _wandb_json_safe(value):
    """Return a JSON-serializable copy of *value* for use as a wandb config.

    Conversion rules, applied recursively:

    * ``argparse.Namespace`` -> dict of its attributes.
    * dict -> dict with every value sanitized; entries whose value is a class
      or a plain Python function are dropped; non-string keys are stringified.
    * list / tuple -> list with every element sanitized.
    * bytes -> str (UTF-8, undecodable bytes ignored).
    * objects exposing ``tolist()`` (e.g. tensors / arrays) -> ``tolist()``.
    * anything else is kept if ``json.dumps`` accepts it, otherwise replaced
      by its ``repr``.
    """
    import json
    import types
    from argparse import Namespace

    if isinstance(value, Namespace):
        value = vars(value)
    if isinstance(value, dict):
        return {
            (key if isinstance(key, str) else str(key)): _wandb_json_safe(item)
            for key, item in value.items()
            # Classes and functions carry no useful config information.
            if not isinstance(item, (type, types.FunctionType))
        }
    if isinstance(value, (list, tuple)):
        # Sanitize element-wise so one odd element does not force the whole
        # sequence to be stringified.
        return [_wandb_json_safe(item) for item in value]
    if isinstance(value, bytes):
        return value.decode('utf-8', errors='ignore')
    if hasattr(value, 'tolist'):
        return value.tolist()
    try:
        json.dumps(value)
    except (TypeError, ValueError):
        # Last resort: keep a readable trace of the value.
        return repr(value)
    return value


def _set_wandb_writer(args):
    """Initialize the global wandb writer on the last rank.

    Nothing is initialized unless ``args.wandb_project`` is set and
    ``args.rank == args.world_size - 1``. The wandb config is a sanitized,
    JSON-serializable copy of ``args``.

    Args:
        args: parsed arguments; reads ``wandb_project``, ``wandb_exp_name``,
            ``wandb_save_dir``, ``wandb_entity``, ``save``, ``rank``,
            ``world_size`` and, if present, ``kitchen_config_file``.

    Raises:
        ValueError: if a wandb project is given but ``wandb_exp_name`` is ''.
    """
    global _GLOBAL_WANDB_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_WANDB_WRITER, 'wandb writer')

    if not getattr(args, 'wandb_project', '') or args.rank != (args.world_size - 1):
        return
    if args.wandb_exp_name == '':
        raise ValueError("Please specify the wandb experiment name!")

    import os

    import wandb

    wandb_config = _wandb_json_safe(args)

    kitchen_config_file = wandb_config.get('kitchen_config_file')
    if kitchen_config_file is not None:
        # Log the contents of the config for discovery of what the
        # quantization settings were.
        with open(kitchen_config_file, "r", encoding="utf-8") as f:
            wandb_config['kitchen_config_file_contents'] = f.read()

    # Defaults to a 'wandb' directory under the checkpoint save dir.
    save_dir = args.wandb_save_dir or os.path.join(args.save, 'wandb')
    os.makedirs(save_dir, exist_ok=True)

    wandb_kwargs = {
        'dir': save_dir,
        'name': args.wandb_exp_name,
        'project': args.wandb_project,
        'config': wandb_config,
    }
    if getattr(args, 'wandb_entity', None):
        wandb_kwargs['entity'] = args.wandb_entity

    wandb.init(**wandb_kwargs)
    _GLOBAL_WANDB_WRITER = wandb