Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 35 additions & 13 deletions megatron/training/global_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,33 +188,55 @@ def _set_tensorboard_writer(args):


def _set_wandb_writer(args):
"""
Minimal-fix version: sanitize only the fields that are known to break JSON
(bytes, tensor, callable) before passing to wandb.
"""
global _GLOBAL_WANDB_WRITER
_ensure_var_is_not_initialized(_GLOBAL_WANDB_WRITER,
'wandb writer')
_ensure_var_is_not_initialized(_GLOBAL_WANDB_WRITER, 'wandb writer')

if getattr(args, 'wandb_project', '') and args.rank == (args.world_size - 1):
if args.wandb_exp_name == '':
raise ValueError("Please specify the wandb experiment name!")

import wandb
if args.wandb_save_dir:
save_dir = args.wandb_save_dir
else:
# Defaults to the save dir.
save_dir = os.path.join(args.save, 'wandb')
wandb_config = vars(args)
import wandb, json, os
from argparse import Namespace

# --- minimal sanitizer ----------------------------------------------
def _clean(obj):
if isinstance(obj, Namespace):
obj = vars(obj)
if isinstance(obj, dict):
return {k: _clean(v) for k, v in obj.items()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: bytes values inside dicts are silently dropped, not converted.

The dict comprehension's filter clause, `if not isinstance(v, (bytes, ...))`, excludes bytes values entirely, so they never reach the recursive `_clean(v)` call. The `if isinstance(obj, bytes)` branch below is therefore only reachable for a top-level bytes value, not for one nested in a dict. As a result, bytes values inside dicts are lost rather than converted to strings.

Suggested change
return {k: _clean(v) for k, v in obj.items()
if isinstance(obj, dict):
result = {}
for k, v in obj.items():
if isinstance(v, (type, type(lambda: None))):
continue
result[k] = _clean(v)
return result

Then the if isinstance(obj, bytes) branch will properly handle bytes values when they are recursively processed as dict values.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

type(lambda: None) only catches FunctionType, missing many callable types.

This check misses built-in functions (len, print), bound methods, functools.partial, and any class instance with __call__. Use callable() instead:

Suggested change
return {k: _clean(v) for k, v in obj.items()
if not (isinstance(v, (bytes, type)) or callable(v))}

if not isinstance(v, (bytes, type, type(lambda: None)))}
if isinstance(obj, bytes):
return obj.decode('utf-8', errors='ignore')
if hasattr(obj, 'tolist'): # torch.Tensor / numpy.ndarray
return obj.tolist()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lists are not recursively sanitized.

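If args contains a list holding tensors, numpy arrays, or other non-serializable objects, its elements are never sanitized individually: a list has no `tolist`, so it falls through to the final `json.dumps` check, fails there, and the whole list is collapsed into a single `repr()` string. The list structure (and any serializable elements in it) is lost from the logged config. A list branch is needed: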

Suggested change
return obj.tolist()
if isinstance(obj, list):
return [_clean(v) for v in obj]
if hasattr(obj, 'tolist'): # torch.Tensor / numpy.ndarray

try: # final JSON safety check
json.dumps(obj)
return obj
except (TypeError, ValueError):
return repr(obj) # last resort: stringify
# ---------------------------------------------------------------------

wandb_config = _clean(args)

if 'kitchen_config_file' in wandb_config and wandb_config['kitchen_config_file'] is not None:
# Log the contents of the config for discovery of what the quantization
# settings were.
with open(wandb_config['kitchen_config_file'], "r") as f:
wandb_config['kitchen_config_file_contents'] = f.read()

save_dir = args.wandb_save_dir or os.path.join(args.save, 'wandb')
os.makedirs(save_dir, exist_ok=True)

wandb_kwargs = {
'dir': save_dir,
'name': args.wandb_exp_name,
'project': args.wandb_project,
'config': wandb_config}
if args.wandb_entity:
if getattr(args, 'wandb_entity', None):
wandb_kwargs['entity'] = args.wandb_entity
os.makedirs(wandb_kwargs['dir'], exist_ok=True)

wandb.init(**wandb_kwargs)
_GLOBAL_WANDB_WRITER = wandb

Expand Down