Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tests/pytorch/test_fused_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,10 @@ def compute_scores_for_aux_loss_pytorch(
scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits.float())
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
elif score_function == "sqrtsoftplus":
scores = torch.nn.functional.softplus(logits.float()).sqrt()
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
else:
raise ValueError(f"Invalid score_function: {score_function}")

Expand Down Expand Up @@ -324,9 +324,9 @@ def test_topk_softmax(


@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
@pytest.mark.parametrize("num_tokens", [2048, 7168])
@pytest.mark.parametrize("num_experts", [256, 128, 32])
@pytest.mark.parametrize("topk", [4, 8])
@pytest.mark.parametrize("topk", [1, 4, 8])
@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
if score_function in ("sigmoid", "sqrtsoftplus"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi

__syncwarp(); // Ensure the scores are written to the output

// Sigmoid/Sqrtsoftplus post-processing when topk > 1
// Sigmoid/Sqrtsoftplus post-processing
if (score_function == 0 || score_function == 2) {
if (topk > 1) {
auto sum_logits =
warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] /= (sum_logits + epsilon);
}
auto sum_logits =
warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id);
for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
local_logits[i] /= (sum_logits + epsilon);
}
__syncwarp();
}
Expand Down Expand Up @@ -246,8 +244,8 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const CompType *int
__syncwarp();
}

// Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward)
if (topk > 1 && (score_function == 0 || score_function == 2)) {
// Sigmoid/Sqrtsoftplus Post-processing bwd (normalization backward)
if (score_function == 0 || score_function == 2) {
// Select the correct activation output buffer:
// - Sigmoid: local_act_from_fwd already contains sigmoid output
// - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above
Expand Down
69 changes: 58 additions & 11 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def _make_graphed_callables(
pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
) -> SingleOrTuple[Callable]:
"""
Helper method for `make_graphed_callables`
Expand Down Expand Up @@ -440,6 +442,8 @@ def hook_fn(
else:
visited_te_modules[func_idx].update(modules)

if pre_warmup_hook is not None:
pre_warmup_hook()
for warmup_iter in range(num_warmup_iters):
hooks = []
for module in func.modules():
Expand Down Expand Up @@ -507,11 +511,8 @@ def hook_fn(
else:
grad_inputs = None
del outputs, grad_inputs
# The following code is added specifically for MCore's special requirements,
# aimed at preventing warmup from altering the control flow.
for module in func.modules():
if hasattr(module, "is_first_microbatch"):
module.is_first_microbatch = True
if post_warmup_hook is not None:
post_warmup_hook()
torch.cuda.synchronize()

# All captures here share a mempool. To avoid replays corrupting each other's memory,
Expand Down Expand Up @@ -782,14 +783,15 @@ class Graphed(torch.autograd.Function):
"""Autograd function for graph replay."""

@staticmethod
def forward(ctx, skip_fp8_weight_update, *inputs):
def forward(ctx, skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *inputs):
# pylint: disable=missing-function-docstring

# Record whether this is the first FP8 module and, if so, set the flag for skipping FP8 weight updates
ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module()
if ctx.is_first_module and skip_fp8_weight_update is not None:
FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(skip_fp8_weight_update)

ctx.cuda_graph_stream = cuda_graph_stream
ctx.cuda_graph_event = cuda_graph_event
# Copy values from new tensors into static tensors
for i in range(len_user_args):
if (
Expand All @@ -799,7 +801,16 @@ def forward(ctx, skip_fp8_weight_update, *inputs):
static_input_surface[i].copy_(inputs[i])

# Replay forward graph
fwd_graph.replay()
if cuda_graph_stream != torch.cuda.current_stream():
cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with cuda_graph_stream:
fwd_graph.replay()
if cuda_graph_event is not None:
torch.cuda.current_stream().wait_event(cuda_graph_event)
else:
torch.cuda.current_stream().wait_stream(cuda_graph_stream)
else:
fwd_graph.replay()
assert isinstance(static_outputs, tuple)
return tuple(o.detach() if o is not None else o for o in static_outputs)

Expand All @@ -816,15 +827,24 @@ def backward(ctx, *grads):
# incoming grad is already in the right place
if g.data_ptr() != grad.data_ptr():
g.copy_(grad)
bwd_graph.replay()
if ctx.cuda_graph_stream != torch.cuda.current_stream():
ctx.cuda_graph_stream.wait_stream(torch.cuda.current_stream())
with ctx.cuda_graph_stream:
bwd_graph.replay()
if ctx.cuda_graph_event is not None:
torch.cuda.current_stream().wait_event(ctx.cuda_graph_event)
else:
torch.cuda.current_stream().wait_stream(ctx.cuda_graph_stream)
else:
bwd_graph.replay()

# Update FP8 scale factors if needed
if ctx.is_first_module:
FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False)

# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return (None,) + tuple(
return (None, None, None) + tuple(
b.detach() if b is not None else b for b in static_grad_inputs
)

Expand All @@ -839,6 +859,23 @@ def functionalized(*user_args, **user_kwargs):

skip_fp8_weight_update = not user_kwargs["is_first_microbatch"]

# cuda_graph_stream and cuda_graph_event control how the TE CUDA graph is replayed.
# When the graph is replayed on cuda_graph_stream, the replay can overlap with
# work on the main stream.
# If cuda_graph_stream is not given, the current stream is used and the graph is
# replayed on it directly.
# If cuda_graph_event is given, it should be an external event recorded in the
# CUDA graph; it is used to sync back to the main stream.
# If cuda_graph_event is not given, it defaults to None and the main stream waits
# for the graph replay to finish.
if "cuda_graph_stream" in user_kwargs:
cuda_graph_stream = user_kwargs["cuda_graph_stream"]
user_kwargs.pop("cuda_graph_stream")
else:
cuda_graph_stream = torch.cuda.current_stream()
if "cuda_graph_event" in user_kwargs:
cuda_graph_event = user_kwargs["cuda_graph_event"]
user_kwargs.pop("cuda_graph_event")
else:
cuda_graph_event = None
# Check that required kwargs are provided
for key in kwargs_keys:
if key not in user_kwargs:
Expand All @@ -854,7 +891,9 @@ def functionalized(*user_args, **user_kwargs):
flatten_user_args, _ = _tree_flatten(user_args)
flatten_user_kwargs, _ = _tree_flatten([user_kwargs[key] for key in kwargs_keys])
func_args = tuple(flatten_user_args) + tuple(flatten_user_kwargs) + module_params
out = Graphed.apply(skip_fp8_weight_update, *func_args)
out = Graphed.apply(
skip_fp8_weight_update, cuda_graph_stream, cuda_graph_event, *func_args
)
return _tree_unflatten(out, output_unflatten_spec)

return functionalized
Expand Down Expand Up @@ -1040,6 +1079,8 @@ def make_graphed_callables(
pool: Optional[Tuple[int, ...]] = None,
retain_graph_in_backward: bool = False,
_reuse_graph_input_output_buffers: bool = False,
pre_warmup_hook: Optional[Callable] = None,
post_warmup_hook: Optional[Callable] = None,
) -> Union[Callable, Tuple[Callable, ...]]:
"""
Make CUDA graph version of Transformer Engine modules
Expand Down Expand Up @@ -1078,6 +1119,10 @@ def make_graphed_callables(
graphs. Only supported with Mcore interleaved pipeline parallelism, i.e.
when `_order` is provided. All callables in `modules` are assumed to have
inputs and outputs with the same dtype and shape.
pre_warmup_hook: callable, default = None
A hook function, called with no arguments, that runs before the warmup iterations.
post_warmup_hook: callable, default = None
A hook function, called with no arguments, that runs after the warmup iterations.

Quantization parameters
-----------------------
Expand Down Expand Up @@ -1264,6 +1309,8 @@ def call_func(self, *args, **kwargs):
pool=pool,
retain_graph_in_backward=retain_graph_in_backward,
_reuse_graph_input_output_buffers=_reuse_graph_input_output_buffers,
pre_warmup_hook=pre_warmup_hook,
post_warmup_hook=post_warmup_hook,
)

# Ensures warmup does not affect numerics for ops such as dropout.
Expand Down
Loading