diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index d691c1828c..c741ad6353 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -21,7 +21,13 @@ def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGr reduce_within_microbatch = tensor_name != "weight" if tensor_name == "weight": if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group + # Do not overwrite with `None`: in torch.distributed collectives + # group=None means the default/world process group. + if tp_group is not None: + reduction_group = tp_group + else: + # "Reduce in TP group" requested, but TP group is missing. + skip_reduction = True else: skip_reduction = True return skip_reduction, reduction_group, reduce_within_microbatch