From f508e662f0b038ea8c97f8ceb2304432a5be5c97 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Mon, 2 Mar 2026 13:30:43 +0800 Subject: [PATCH 1/3] [PyTorch] Remove `is_first_microbatch` setting after cudagraph warmup (#2715) Remove is_first_microbatch setting after warmup Signed-off-by: Robin Zhang --- transformer_engine/pytorch/graph.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f4b1fb23ae..d3320fd70f 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -507,11 +507,6 @@ def hook_fn( else: grad_inputs = None del outputs, grad_inputs - # The following code is added specifically for MCore's special requirements, - # aimed at preventing warmup from altering the control flow. - for module in func.modules(): - if hasattr(module, "is_first_microbatch"): - module.is_first_microbatch = True torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, From 537f134236d10cf50a3bb296f453301079b5d7d5 Mon Sep 17 00:00:00 2001 From: Tong Liu Date: Mon, 2 Mar 2026 16:16:10 +0800 Subject: [PATCH 2/3] [Common][PyTorch] Fix normalization for `fused_score_for_moe_aux_loss` (#2720) * fix topk=1 Signed-off-by: tongliu * add topk=1 ut Signed-off-by: tongliu --------- Signed-off-by: tongliu --- tests/pytorch/test_fused_router.py | 8 ++++---- .../fused_router/fused_score_for_moe_aux_loss.cu | 16 +++++++--------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 64000e109e..36c09060ed 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -113,10 +113,10 @@ def compute_scores_for_aux_loss_pytorch( scores = torch.softmax(logits, dim=-1, dtype=torch.float32) elif score_function == "sigmoid": scores = torch.sigmoid(logits.float()) - scores = scores / (scores.sum(dim=-1, keepdim=True) + 
1e-20) if topk > 1 else scores + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) elif score_function == "sqrtsoftplus": scores = torch.nn.functional.softplus(logits.float()).sqrt() - scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) else: raise ValueError(f"Invalid score_function: {score_function}") @@ -324,9 +324,9 @@ def test_topk_softmax( @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) +@pytest.mark.parametrize("num_tokens", [2048, 7168]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) -@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("topk", [1, 4, 8]) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): if score_function in ("sigmoid", "sqrtsoftplus"): diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 4f405e0a25..d38fcde6bf 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -109,14 +109,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi __syncwarp(); //Confirm the scores is written to the output - // Sigmoid/Sqrtsoftplus post-processing when topk > 1 + // Sigmoid/Sqrtsoftplus post-processing if (score_function == 0 || score_function == 2) { - if (topk > 1) { - auto sum_logits = - warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] /= (sum_logits + epsilon); - } + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); + for (int i = 
lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] /= (sum_logits + epsilon); } __syncwarp(); } @@ -246,8 +244,8 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int __syncwarp(); } - // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Sigmoid/Sqrtsoftplus Post-processing bwd (normalization backward) + if (score_function == 0 || score_function == 2) { // Select the correct activation output buffer: // - Sigmoid: local_act_from_fwd already contains sigmoid output // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above From bba7bf6a4101e150d1aaf9278608385214d09684 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Mon, 2 Mar 2026 16:39:41 +0800 Subject: [PATCH 3/3] [PyTorch] Support cuda graph capturing offloading module (#2435) * support cuda graph capture offloading module Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove reset_hook and init_chunk_handler_hook Signed-off-by: Hongbin Liu * remove reset_hook and init_chunk_handler_hook Signed-off-by: Hongbin Liu * minor fix Signed-off-by: root * temp fix overlap-grad-reduce Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reuse mark_not_offload() and do not offload scale_inv Signed-off-by: Hongbin Liu * temp fix for mxfp8 Signed-off-by: Hongbin Liu * fix bug for record_stream and from_blob Signed-off-by: Hongbin Liu * disable offloading core_attn_out and refine cpu overhead of at::empty Signed-off-by: Hongbin Liu * minor fix Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * return ptr of whole buffer and offload the whole buffer Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Signed-off-by: Hongbin Liu * remove code changes of offloading and quantizer Signed-off-by: Hongbin Liu * minor fix Signed-off-by: Hongbin Liu * minor fix Signed-off-by: Hongbin Liu * minor fix Signed-off-by: Hongbin Liu * minor fix Signed-off-by: Hongbin Liu * minor fix Signed-off-by: Hongbin Liu * add docstring Signed-off-by: Hongbin Liu --------- Signed-off-by: Hongbin Liu Signed-off-by: root Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: root Co-authored-by: root --- transformer_engine/pytorch/graph.py | 64 ++++++++++++++++++++++++++--- 1 file changed, 58 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index d3320fd70f..bae911b4e1 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -108,6 +108,8 @@ def _make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -440,6 +442,8 @@ def hook_fn( else: visited_te_modules[func_idx].update(modules) + if pre_warmup_hook is not None: + pre_warmup_hook() for warmup_iter in range(num_warmup_iters): hooks = [] for module in func.modules(): @@ -507,6 +511,8 @@ def hook_fn( else: grad_inputs = None del outputs, grad_inputs + if post_warmup_hook is not None: + post_warmup_hook() torch.cuda.synchronize() # All captures here share a mempool. 
To avoid replays corrupting each other's memory, @@ -777,14 +783,15 @@ class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" @staticmethod - def forward(ctx, skip_fp8_weight_update, *inputs): + def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs): # pylint: disable=missing-function-docstring # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) - + ctx.cuda_graph_stream = cuda_graph_stream + ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors for i in range(len_user_args): if ( @@ -794,7 +801,16 @@ def forward(ctx, skip_fp8_weight_update, *inputs): static_input_surface[i].copy_(inputs[i]) # Replay forward graph - fwd_graph.replay() + if cuda_graph_stream != torch.cuda.current_stream(): + cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with cuda_graph_stream: + fwd_graph.replay() + if cuda_graph_event is not None: + torch.cuda.current_stream().wait_event(cuda_graph_event) + else: + torch.cuda.current_stream().wait_stream(cuda_graph_stream) + else: + fwd_graph.replay() assert isinstance(static_outputs, tuple) return tuple(o.detach() if o is not None else o for o in static_outputs) @@ -811,7 +827,16 @@ def backward(ctx, *grads): # incoming grad is already in the right place if g.data_ptr() != grad.data_ptr(): g.copy_(grad) - bwd_graph.replay() + if ctx.cuda_graph_stream != torch.cuda.current_stream(): + ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with ctx.cuda_graph_stream: + bwd_graph.replay() + if ctx.cuda_graph_event is not None: + torch.cuda.current_stream().wait_event(ctx.cuda_graph_event) + else: + torch.cuda.current_stream().wait_stream(ctx.cuda_graph_stream) + else: + bwd_graph.replay() # Update FP8 scale factors if needed 
if ctx.is_first_module: @@ -819,7 +844,7 @@ def backward(ctx, *grads): # Input args that didn't require grad expect a None gradient. assert isinstance(static_grad_inputs, tuple) - return (None,) + tuple( + return (None, None, None) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) @@ -834,6 +859,23 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + # The cuda_graph_stream and cuda_graph_event are used in the TE CUDA graph replay. + # When replaying the graph in the cuda graph stream, the graph replay could overlap + # with the work on main stream. + # When cuda_graph_event is given, it should be an external event recorded + # in the cuda graph and is used to sync-back to the main stream. + # If cuda_graph_event is not given, it will be None and the graph replay will block + # the main stream until it is finished. + if "cuda_graph_stream" in user_kwargs: + cuda_graph_stream = user_kwargs["cuda_graph_stream"] + user_kwargs.pop("cuda_graph_stream") + else: + cuda_graph_stream = torch.cuda.current_stream() + if "cuda_graph_event" in user_kwargs: + cuda_graph_event = user_kwargs["cuda_graph_event"] + user_kwargs.pop("cuda_graph_event") + else: + cuda_graph_event = None # Check that required kwargs are provided for key in kwargs_keys: if key not in user_kwargs: @@ -849,7 +891,9 @@ def functionalized(*user_args, **user_kwargs): flatten_user_args, _ = _tree_flatten(user_args) flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params - out = Graphed.apply(skip_fp8_weight_update, *func_args) + out = Graphed.apply( + skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args + ) return _tree_unflatten(out, output_unflatten_spec) return functionalized @@ -1035,6 +1079,8 @@ def make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = 
False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1073,6 +1119,10 @@ def make_graphed_callables( graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. + pre_warmup_hook: callable, default = None + A hook function that will be called before the warmup iterations. + post_warmup_hook: callable, default = None + A hook function that will be called after the warmup iterations. Quantization parameters ----------------------- @@ -1259,6 +1309,8 @@ def call_func(self, *args, **kwargs): pool=pool, retain_graph_in_backward=retain_graph_in_backward, _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, + pre_warmup_hook=pre_warmup_hook, + post_warmup_hook=post_warmup_hook, ) # Ensures warmup does not affect numerics for ops such as dropout.