diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index 64000e109e..36c09060ed 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -113,10 +113,10 @@ def compute_scores_for_aux_loss_pytorch( scores = torch.softmax(logits, dim=-1, dtype=torch.float32) elif score_function == "sigmoid": scores = torch.sigmoid(logits.float()) - scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) elif score_function == "sqrtsoftplus": scores = torch.nn.functional.softplus(logits.float()).sqrt() - scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) else: raise ValueError(f"Invalid score_function: {score_function}") @@ -324,9 +324,9 @@ def test_topk_softmax( @pytest.mark.parametrize("dtype", [torch.float32]) -@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) +@pytest.mark.parametrize("num_tokens", [2048, 7168]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) -@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("topk", [1, 4, 8]) @pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): if score_function in ("sigmoid", "sqrtsoftplus"): diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 4f405e0a25..d38fcde6bf 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -109,14 +109,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi __syncwarp(); //Confirm the scores is written to the output - // Sigmoid/Sqrtsoftplus post-processing when topk > 1 + // Sigmoid/Sqrtsoftplus post-processing if (score_function == 0 || score_function == 2) { - if (topk > 1) { - auto sum_logits = - warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); - for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] /= (sum_logits + epsilon); - } + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_logits[i] /= (sum_logits + epsilon); } __syncwarp(); } @@ -246,8 +244,8 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int __syncwarp(); } - // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) - if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Sigmoid/Sqrtsoftplus Post-processing bwd (normalization backward) + if (score_function == 0 || score_function == 2) { // Select the correct activation output buffer: // - Sigmoid: local_act_from_fwd already contains sigmoid output // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f4b1fb23ae..bae911b4e1 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -108,6 +108,8 @@ def _make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> SingleOrTuple[Callable]: """ Helper method for `make_graphed_callables` @@ -440,6 +442,8 @@ def hook_fn( else: visited_te_modules[func_idx].update(modules) + if pre_warmup_hook is not None: + pre_warmup_hook() for warmup_iter in range(num_warmup_iters): hooks = [] for module in func.modules(): @@ -507,11 +511,8 @@ def hook_fn( else: grad_inputs = None del outputs, grad_inputs - # The following code is added specifically for MCore's special requirements, - # aimed at preventing warmup from altering the control flow. - for module in func.modules(): - if hasattr(module, "is_first_microbatch"): - module.is_first_microbatch = True + if post_warmup_hook is not None: + post_warmup_hook() torch.cuda.synchronize() # All captures here share a mempool. To avoid replays corrupting each other's memory, @@ -782,14 +783,15 @@ class Graphed(torch.autograd.Function): """Autograd function for graph replay.""" @staticmethod - def forward(ctx, skip_fp8_weight_update, *inputs): + def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs): # pylint: disable=missing-function-docstring # Set flag for whether to update FP8 weight updates ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() if ctx.is_first_module and skip_fp8_weight_update is not None: FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update) - + ctx.cuda_graph_stream = cuda_graph_stream + ctx.cuda_graph_event = cuda_graph_event # Copy values from new tensors into static tensors for i in range(len_user_args): if ( @@ -799,7 +801,16 @@ def forward(ctx, skip_fp8_weight_update, *inputs): static_input_surface[i].copy_(inputs[i]) # Replay forward graph - fwd_graph.replay() + if cuda_graph_stream != torch.cuda.current_stream(): + cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with cuda_graph_stream: + fwd_graph.replay() + if cuda_graph_event is not None: + torch.cuda.current_stream().wait_event(cuda_graph_event) + else: + torch.cuda.current_stream().wait_stream(cuda_graph_stream) + else: + fwd_graph.replay() assert isinstance(static_outputs, tuple) return tuple(o.detach() if o is not None else o for o in static_outputs) @@ -816,7 +827,16 @@ def backward(ctx, *grads): # incoming grad is already in the right place if g.data_ptr() != grad.data_ptr(): g.copy_(grad) - bwd_graph.replay() + if ctx.cuda_graph_stream != torch.cuda.current_stream(): + ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream()) + with ctx.cuda_graph_stream: + bwd_graph.replay() + if ctx.cuda_graph_event is not None: + torch.cuda.current_stream().wait_event(ctx.cuda_graph_event) + else: + torch.cuda.current_stream().wait_stream(ctx.cuda_graph_stream) + else: + bwd_graph.replay() # Update FP8 scale factors if needed if ctx.is_first_module: @@ -824,7 +844,7 @@ def backward(ctx, *grads): # Input args that didn't require grad expect a None gradient. assert isinstance(static_grad_inputs, tuple) - return (None,) + tuple( + return (None, None, None) + tuple( b.detach() if b is not None else b for b in static_grad_inputs ) @@ -839,6 +859,23 @@ def functionalized(*user_args, **user_kwargs): skip_fp8_weight_update = not user_kwargs["is_first_microbatch"] + # The cuda_graph_stream and cuda_graph_event are used in the TE CUDA graph replay. + # When replaying the graph in the cuda graph stream, the graph replay could overlap + # with the work on main stream. + # When cuda_graph_event is given, it should be an external event recorded + # in the cuda graph and is used to sync-back to the main stream. + # If cuda_graph_event is not given, it will be None and the graph replay will block + # the main stream until it is finished. + if "cuda_graph_stream" in user_kwargs: + cuda_graph_stream = user_kwargs["cuda_graph_stream"] + user_kwargs.pop("cuda_graph_stream") + else: + cuda_graph_stream = torch.cuda.current_stream() + if "cuda_graph_event" in user_kwargs: + cuda_graph_event = user_kwargs["cuda_graph_event"] + user_kwargs.pop("cuda_graph_event") + else: + cuda_graph_event = None # Check that required kwargs are provided for key in kwargs_keys: if key not in user_kwargs: @@ -854,7 +891,9 @@ def functionalized(*user_args, **user_kwargs): flatten_user_args, _ = _tree_flatten(user_args) flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys]) func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params - out = Graphed.apply(skip_fp8_weight_update, *func_args) + out = Graphed.apply( + skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args + ) return _tree_unflatten(out, output_unflatten_spec) return functionalized @@ -1040,6 +1079,8 @@ def make_graphed_callables( pool: Optional[Tuple[int, ...]] = None, retain_graph_in_backward: bool = False, _reuse_graph_input_output_buffers: bool = False, + pre_warmup_hook: Optional[Callable] = None, + post_warmup_hook: Optional[Callable] = None, ) -> Union[Callable, Tuple[Callable, ...]]: """ Make CUDA graph version of Transformer Engine modules @@ -1078,6 +1119,10 @@ def make_graphed_callables( graphs. Only supported with Mcore interleaved pipeline parallelism, i.e. when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. + pre_warmup_hook: callable, default = None + A hook function that will be called before the warmup iterations. + post_warmup_hook: callable, default = None + A hook function that will be called after the warmup iterations. Quantization parameters ----------------------- @@ -1264,6 +1309,8 @@ def call_func(self, *args, **kwargs): pool=pool, retain_graph_in_backward=retain_graph_in_backward, _reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers, + pre_warmup_hook=pre_warmup_hook, + post_warmup_hook=post_warmup_hook, ) # Ensures warmup does not affect numerics for ops such as dropout.