diff --git a/tests/pytorch/test_linear_cross_entropy.py b/tests/pytorch/test_linear_cross_entropy.py new file mode 100644 index 0000000000..07aed2d067 --- /dev/null +++ b/tests/pytorch/test_linear_cross_entropy.py @@ -0,0 +1,1296 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import contextlib +import os +import typing +from contextlib import ExitStack + +import numpy as np +import pytest +import torch +import torch.distributed as dist + +from transformer_engine.pytorch import linear_cross_entropy + +@pytest.mark.skipif( + "WORLD_SIZE" in os.environ and os.environ["WORLD_SIZE"] != "1", reason="Requires single GPU" +) +class TestFusedLinearCrossEntropyDataParallel: + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + + gc.collect() + torch.cuda.synchronize() + + @staticmethod + def torch_linear_cross_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: str, + ignore_index: int, + ): + # NOTE: upcast both operands to fp32 so the matmul accumulates in fp32, + # which keeps this reference result accurate + logits = hidden.to(torch.float32) @ weight.T.to(torch.float32) + logprobs = torch.nn.functional.cross_entropy( + logits.view(-1, logits.shape[-1]), + labels.view(-1), + reduction=reduction, + ignore_index=ignore_index, + ) + return logprobs.to(torch.float32) + + @staticmethod + def get_problems(): + return [ + (80, 125, 64), + (80, 152064, 64), + (1024, 152064, 4096), + (4096, 152063, 8192), + ((1, 4096), 152064, 8192), + ((2, 4096), 152064, 8192), + ] + + @staticmethod + def get_ignore_index(): + return [-100, 4] + + def test_kernel_launch(self): + """ + Check if the compiled kernel can be + launched with different problem sizes + """ + self.cleanup() + + num_tokens = [15, 26, 128, 513, 2048, 8192] + vocab_size = 152064 + dim = 4096 + dtype = torch.bfloat16 + reduction = "mean" + ignore_index = -100 + + weight = torch.randn(vocab_size, 
dim, dtype=dtype, device="cuda").requires_grad_() + for num_token in num_tokens: + hidden = torch.randn(num_token, dim, dtype=dtype, device="cuda").requires_grad_() + labels = torch.randint(0, vocab_size, (num_token,), dtype=torch.long, device="cuda") + + logprobs = linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index + ) + assert not torch.isnan(logprobs).any() + + gLogprobs = torch.randn_like(logprobs) + (d_hidden, d_weight) = torch.autograd.grad( + (logprobs,), (hidden, weight), (gLogprobs,), retain_graph=False + ) + assert not torch.isnan(d_hidden).any() + assert not torch.isnan(d_weight).any() + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("problem", get_problems()) + @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) + @pytest.mark.parametrize("ignore_index", get_ignore_index()) + def test_correctness(self, dtype, problem, reduction, ignore_index): + num_tokens, vocabsize, dim = problem + hidden_shape = (num_tokens, dim) if isinstance(num_tokens, int) else (*num_tokens, dim) + labels_shape = (num_tokens,) if isinstance(num_tokens, int) else num_tokens + + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + if ignore_index >= 0 and ignore_index < vocabsize: + pad_labels = torch.nn.functional.pad(labels, (0, 1), value=ignore_index) + labels = pad_labels[..., 1:].contiguous() + + # forward + torch_logprobs = self.torch_linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index + ) + + custom_logprobs = linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index + ) + + torch.testing.assert_close(torch_logprobs, 
custom_logprobs) + + # backward + g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1) + + (d_torch_hidden, d_torch_weight) = torch.autograd.grad( + (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + + (d_custom_hidden, d_custom_weight) = torch.autograd.grad( + (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + + torch.testing.assert_close(d_torch_hidden, d_custom_hidden, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(d_torch_weight, d_custom_weight, atol=1e-3, rtol=1e-3) + + @pytest.mark.parametrize("problem", [((1, 4096), 129280, 7168)]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + @pytest.mark.parametrize("reduction", ["mean"]) + @pytest.mark.parametrize("ignore_index", [-100]) + def test_performance(self, problem, dtype, reduction, ignore_index): + num_tokens, vocabsize, dim = problem + hidden_shape = (num_tokens, dim) if isinstance(num_tokens, int) else (*num_tokens, dim) + labels_shape = (num_tokens,) if isinstance(num_tokens, int) else num_tokens + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch_fwd_latency = list() + torch_bwd_latency = list() + custom_fwd_latency = list() + custom_bwd_latency = list() + + iterations = 5 + for i in range(iterations): + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + if ignore_index >= 0 and ignore_index < vocabsize: + pad_labels = torch.nn.functional.pad(labels, (0, 1), value=ignore_index) + labels = pad_labels[..., 1:].contiguous() + + # -------- forward -------- # + start_event.record() + torch_logprobs = self.torch_linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index 
+ ) + end_event.record() + torch.cuda.synchronize() + torch_fwd_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + custom_logprobs = linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index + ) + end_event.record() + torch.cuda.synchronize() + custom_fwd_latency.append(start_event.elapsed_time(end_event)) + + # -------- backward -------- # + g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1) + + start_event.record() + (d_torch_hidden, d_torch_weight) = torch.autograd.grad( + (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + torch_bwd_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (d_custom_hidden, d_custom_weight) = torch.autograd.grad( + (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + custom_bwd_latency.append(start_event.elapsed_time(end_event)) + + # --- remove first latency due to warmup --- # + torch_fwd_latency = torch_fwd_latency[1:] + torch_bwd_latency = torch_bwd_latency[1:] + custom_fwd_latency = custom_fwd_latency[1:] + custom_bwd_latency = custom_bwd_latency[1:] + + print() + print(f"[INFO]: On problem {problem}, dtype {dtype}, reduction {reduction}:") + print( + f"[INFO]: Torch forward latency: {sum(torch_fwd_latency) / len(torch_fwd_latency):.2f} ms" + ) + print( + f"[INFO]: Custom forward latency: {sum(custom_fwd_latency) / len(custom_fwd_latency):.2f} ms" + ) + print( + f"[INFO]: Torch backward latency: {sum(torch_bwd_latency) / len(torch_bwd_latency):.2f} ms" + ) + print( + f"[INFO]: Custom backward latency: {sum(custom_bwd_latency) / len(custom_bwd_latency):.2f} ms" + ) + + @pytest.mark.parametrize("problem", [((1, 4096), 129280, 7168)]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + @pytest.mark.parametrize("reduction", ["mean"]) + @pytest.mark.parametrize("ignore_index", 
[-100]) + def test_storage(self, problem, dtype, reduction, ignore_index): + num_tokens, vocabsize, dim = problem + hidden_shape = (num_tokens, dim) if isinstance(num_tokens, int) else (*num_tokens, dim) + labels_shape = (num_tokens,) if isinstance(num_tokens, int) else num_tokens + print() + print(f"[INFO]: On problem {problem}, dtype {dtype}, reduction {reduction}:") + + def torch_storage(): + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + if ignore_index >= 0 and ignore_index < vocabsize: + pad_labels = torch.nn.functional.pad(labels, (0, 1), value=ignore_index) + labels = pad_labels[..., 1:].contiguous() + + torch.cuda.reset_peak_memory_stats() + torch_logprobs = self.torch_linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index + ) + torch.cuda.synchronize() + torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + print(f"[INFO]: Torch Forward pass peak memory: {torch_max_memory:.2f} MB") + + torch.cuda.reset_peak_memory_stats() + g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1) + (d_torch_hidden, d_torch_weight) = torch.autograd.grad( + (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + torch.cuda.synchronize() + torch_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + print(f"[INFO]: Torch Backward pass peak memory: {torch_backward_max_memory:.2f} MB") + + def custom_storage(): + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, 
dtype=torch.long, device="cuda") + if ignore_index >= 0 and ignore_index < vocabsize: + pad_labels = torch.nn.functional.pad(labels, (0, 1), value=ignore_index) + labels = pad_labels[..., 1:].contiguous() + + torch.cuda.reset_peak_memory_stats() + custom_logprobs = linear_cross_entropy( + hidden, weight, labels, reduction=reduction, ignore_index=ignore_index + ) + torch.cuda.synchronize() + custom_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + print(f"[INFO]: Custom Forward pass peak memory: {custom_max_memory:.2f} MB") + + torch.cuda.reset_peak_memory_stats() + g_logprobs = torch.empty_like(custom_logprobs).uniform_(-0.1, 0.1) + (d_custom_hidden, d_custom_weight) = torch.autograd.grad( + (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + torch.cuda.synchronize() + custom_backward_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + print(f"[INFO]: Custom Backward pass peak memory: {custom_backward_max_memory:.2f} MB") + + self.cleanup() + torch_storage() + self.cleanup() + custom_storage() + + +@pytest.mark.skipif( + ("WORLD_SIZE" not in os.environ or int(os.environ["WORLD_SIZE"]) < 2), # or True, + reason="Requires torchrun with multiple GPUs", +) +class TestFusedLinearCrossEntropyTensorParallel: + @classmethod + def setup_class(cls): + if dist.is_initialized(): + cls.must_teardown = False + else: + dist.init_process_group( + backend="nccl", + init_method="env://", + world_size=int(os.environ["WORLD_SIZE"]), + rank=int(os.environ["RANK"]), + ) + cls.must_teardown = True + cls.tp_group = dist.group.WORLD + + cls.tp_rank = dist.get_rank(cls.tp_group) + cls.tp_world_size = dist.get_world_size(cls.tp_group) + cls.is_chief = cls.tp_rank == 0 + device = torch.device(f"cuda:{cls.tp_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: TP rank: {cls.tp_rank}, TP world size: {cls.tp_world_size}") + + @classmethod + def teardown_class(cls): + if cls.must_teardown: + dist.destroy_process_group() + + def cleanup(self): 
+ torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + + gc.collect() + torch.cuda.synchronize() + + @staticmethod + def torch_linear_cross_entropy_single_gpu( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[str] = "mean", + ): + logits = hidden.to(torch.float32) @ weight.T.to(torch.float32) + logprobs = torch.nn.functional.cross_entropy( + logits.view(-1, logits.shape[-1]), labels.view(-1), reduction=reduction + ) + return logprobs.to(torch.float32) + + class TorchLinearCrossEntropy(torch.autograd.Function): + @staticmethod + def forward( + ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + reduction: typing.Optional[str] = "mean", + ): + tp_rank = 0 if tp_group is None else torch.distributed.get_rank(tp_group) + tp_world_size = 1 if tp_group is None else torch.distributed.get_world_size(tp_group) + + logits = hidden.to(torch.float32) @ weight.T.to(torch.float32) + + whole_logits = torch.empty( + (logits.shape[0], logits.shape[-1] * tp_world_size), + dtype=logits.dtype, + device=logits.device, + ) + whole_logits_ref = [ + whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]] + for i in range(tp_world_size) + ] + dist.all_gather(whole_logits_ref, logits, group=tp_group) + + logprobs = torch.nn.functional.cross_entropy( + whole_logits.view(-1, whole_logits.shape[-1]), labels.view(-1), reduction=reduction + ) + + # If we don't preserve whole_logits, + # we need to re-compute it in the backward pass + ctx.save_for_backward(hidden, weight, labels) + ctx.tp_group = tp_group + ctx.reduction = reduction + ctx.tp_rank = tp_rank + ctx.tp_world_size = tp_world_size + + return logprobs.to(torch.float32) + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor): + hidden, weight, labels = ctx.saved_tensors + tp_group = ctx.tp_group + reduction = ctx.reduction + tp_rank = ctx.tp_rank + tp_world_size = 
ctx.tp_world_size + + num_tokens, dim = hidden.shape + + if reduction == "mean": + _g_logprobs = torch.broadcast_to(g_logprobs / num_tokens, (num_tokens,)) + elif reduction == "sum": + _g_logprobs = torch.broadcast_to(g_logprobs, (num_tokens,)) + else: + _g_logprobs = g_logprobs + + # re-compute whole_logits + logits = hidden.to(torch.float32) @ weight.T.to(torch.float32) + whole_logits = torch.empty( + (logits.shape[0], logits.shape[-1] * tp_world_size), + dtype=logits.dtype, + device=logits.device, + ) + whole_logits_ref = [ + whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]] + for i in range(tp_world_size) + ] + dist.all_gather(whole_logits_ref, logits, group=tp_group) + + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.view(-1).unsqueeze(-1), 1) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + d_logits = (pd - one_hot) * _g_logprobs.unsqueeze(-1) + d_logits = d_logits.to(hidden.dtype) + + local_size = weight.size(0) + local_d_logits = d_logits[:, tp_rank * local_size : (tp_rank + 1) * local_size] + + local_d_hidden = local_d_logits @ weight + local_d_weight = local_d_logits.T @ hidden + + dist.all_reduce(local_d_hidden, op=dist.ReduceOp.SUM, group=tp_group) + + return local_d_hidden, local_d_weight, None, None, None + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + @pytest.mark.parametrize("problem", [(4096, 129280, 8192)]) + def test_torch_tp_vs_single_gpu(self, dtype, reduction, problem): + num_tokens, vocabsize, dim = problem + + hidden = ( + torch.empty((num_tokens, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, (num_tokens,), dtype=torch.long, device="cuda") + + # ------------ forward pass ------------ # + dist.broadcast(hidden, 
src=0, group=self.tp_group) + dist.broadcast(labels, src=0, group=self.tp_group) + + # single GPU + whole_weight = torch.empty( + (vocabsize * self.tp_world_size, dim), dtype=dtype, device="cuda" + ) + whole_weight_view = [ + whole_weight[i * vocabsize : (i + 1) * vocabsize, :] for i in range(self.tp_world_size) + ] + dist.all_gather(whole_weight_view, weight, group=self.tp_group) + whole_weight = whole_weight.clone().requires_grad_() + logprobs_single_gpu = self.torch_linear_cross_entropy_single_gpu( + hidden, whole_weight, labels, reduction=reduction + ) + + # TP + logprobs_tp = self.TorchLinearCrossEntropy.apply( + hidden, weight, labels, self.tp_group, reduction + ) + torch.testing.assert_close(logprobs_single_gpu, logprobs_tp) + + # ------------ backward pass ------------ # + g_logprobs = torch.empty_like(logprobs_single_gpu).uniform_(-0.1, 0.1) + dist.broadcast(g_logprobs, src=0, group=self.tp_group) + + # single GPU + (d_hidden_single_gpu, d_weight_single_gpu) = torch.autograd.grad( + (logprobs_single_gpu,), (hidden, whole_weight), (g_logprobs,), retain_graph=False + ) + + # TP + (d_hidden_tp, d_weight_tp) = torch.autograd.grad( + (logprobs_tp,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + torch.testing.assert_close(d_hidden_single_gpu, d_hidden_tp, atol=1e-3, rtol=1e-3) + local_d_weight_single_gpu = d_weight_single_gpu[ + self.tp_rank * weight.shape[0] : (self.tp_rank + 1) * weight.shape[0], : + ] + torch.testing.assert_close(local_d_weight_single_gpu, d_weight_tp, atol=1e-3, rtol=1e-3) + + @staticmethod + def get_problems(): + return [ + (80, 125, 64), + (80, 152064, 64), + (1024, 152064, 4096), + (4096, 152063, 8192), + ((1, 4096), 152064, 8192), + ((2, 4096), 152064, 8192), + ] + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + @pytest.mark.parametrize("problem", get_problems()) + def test_correctness(self, dtype, reduction, problem): + num_tokens, 
vocabsize, dim = problem + hidden_shape = (num_tokens, dim) if isinstance(num_tokens, int) else (*num_tokens, dim) + labels_shape = (num_tokens,) if isinstance(num_tokens, int) else num_tokens + + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + + # ------ forward pass ------ # + dist.broadcast(hidden, src=0, group=self.tp_group) + dist.broadcast(labels, src=0, group=self.tp_group) + + torch_logprobs = self.TorchLinearCrossEntropy.apply( + hidden.view(-1, dim), weight, labels, self.tp_group, reduction + ) + + custom_logprobs = linear_cross_entropy( + hidden, weight, labels, tp_group=self.tp_group, reduction=reduction + ) + + torch.testing.assert_close(torch_logprobs, custom_logprobs) + + # ------- backward pass ------- # + g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1) + dist.broadcast(g_logprobs, src=0, group=self.tp_group) + + (d_hidden_torch, d_weight_torch) = torch.autograd.grad( + (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + (d_hidden_custom, d_weight_custom) = torch.autograd.grad( + (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + torch.testing.assert_close(d_hidden_torch, d_hidden_custom, atol=1e-3, rtol=1e-3) + torch.testing.assert_close(d_weight_torch, d_weight_custom, atol=1e-4, rtol=1e-4) + + @pytest.mark.parametrize("problem", [((1, 4096), 129280, 7168)]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + @pytest.mark.parametrize("reduction", ["mean"]) + def test_performance(self, problem, dtype, reduction): + num_tokens, vocabsize, dim = problem + hidden_shape = (num_tokens, dim) if isinstance(num_tokens, int) else (*num_tokens, dim) + labels_shape = (num_tokens,) if isinstance(num_tokens, int) else 
num_tokens + + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch_fwd_latency = list() + torch_bwd_latency = list() + custom_fwd_latency = list() + custom_bwd_latency = list() + + iterations = 5 + for i in range(iterations): + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + + # ------ forward pass ------ # + dist.broadcast(hidden, src=0, group=self.tp_group) + dist.broadcast(labels, src=0, group=self.tp_group) + + start_event.record() + torch_logprobs = self.TorchLinearCrossEntropy.apply( + hidden.view(-1, dim), weight, labels, self.tp_group, reduction + ) + end_event.record() + torch.cuda.synchronize() + torch_fwd_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + custom_logprobs = linear_cross_entropy( + hidden, weight, labels, tp_group=self.tp_group, reduction=reduction + ) + end_event.record() + torch.cuda.synchronize() + custom_fwd_latency.append(start_event.elapsed_time(end_event)) + + # ------- backward pass ------- # + g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1) + dist.broadcast(g_logprobs, src=0, group=self.tp_group) + + start_event.record() + (d_hidden_torch, d_weight_torch) = torch.autograd.grad( + (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + torch_bwd_latency.append(start_event.elapsed_time(end_event)) + + start_event.record() + (d_hidden_custom, d_weight_custom) = torch.autograd.grad( + (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + end_event.record() + torch.cuda.synchronize() + custom_bwd_latency.append(start_event.elapsed_time(end_event)) + + # --- remove first 
latency due to warmup --- # + torch_fwd_latency = torch_fwd_latency[1:] + torch_bwd_latency = torch_bwd_latency[1:] + custom_fwd_latency = custom_fwd_latency[1:] + custom_bwd_latency = custom_bwd_latency[1:] + + if self.is_chief: + print() + print( + f"[INFO]: On problem {problem}, dtype {dtype}, reduction {reduction}, TP size {self.tp_world_size}:" + ) + print( + f"[INFO]: Torch forward latency: {sum(torch_fwd_latency) / len(torch_fwd_latency):.2f} ms" + ) + print( + f"[INFO]: Custom forward latency: {sum(custom_fwd_latency) / len(custom_fwd_latency):.2f} ms" + ) + print( + f"[INFO]: Torch backward latency: {sum(torch_bwd_latency) / len(torch_bwd_latency):.2f} ms" + ) + print( + f"[INFO]: Custom backward latency: {sum(custom_bwd_latency) / len(custom_bwd_latency):.2f} ms" + ) + + @pytest.mark.parametrize("problem", [((1, 4096), 129280, 7168)]) + @pytest.mark.parametrize("dtype", [torch.bfloat16]) + @pytest.mark.parametrize("reduction", ["mean"]) + def test_storage(self, problem, dtype, reduction): + num_tokens, vocabsize, dim = problem + hidden_shape = (num_tokens, dim) if isinstance(num_tokens, int) else (*num_tokens, dim) + labels_shape = (num_tokens,) if isinstance(num_tokens, int) else num_tokens + + if self.is_chief: + print() + print( + f"[INFO]: On problem {problem}, dtype {dtype}, reduction {reduction}, TP size {self.tp_world_size}:" + ) + + def torch_storage(): + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + + dist.broadcast(hidden, src=0, group=self.tp_group) + dist.broadcast(labels, src=0, group=self.tp_group) + + torch.cuda.reset_peak_memory_stats() + torch_logprobs = self.TorchLinearCrossEntropy.apply( + hidden.view(-1, dim), weight, labels, self.tp_group, reduction + ) + 
torch.cuda.synchronize() + torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + if self.is_chief: + print( + f"[INFO]: On GPU {self.tp_rank}, Torch Forward pass peak memory: {torch_max_memory:.2f} MB" + ) + + g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1) + dist.broadcast(g_logprobs, src=0, group=self.tp_group) + + torch.cuda.reset_peak_memory_stats() + (d_hidden_torch, d_weight_torch) = torch.autograd.grad( + (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + torch.cuda.synchronize() + torch_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + if self.is_chief: + print( + f"[INFO]: On GPU {self.tp_rank}, Torch Backward pass peak memory: {torch_max_memory:.2f} MB" + ) + + def custom_storage(): + hidden = ( + torch.empty(hidden_shape, dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint(0, vocabsize, labels_shape, dtype=torch.long, device="cuda") + + dist.broadcast(hidden, src=0, group=self.tp_group) + dist.broadcast(labels, src=0, group=self.tp_group) + + torch.cuda.reset_peak_memory_stats() + custom_logprobs = linear_cross_entropy( + hidden, weight, labels, tp_group=self.tp_group, reduction=reduction + ) + torch.cuda.synchronize() + custom_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + if self.is_chief: + print( + f"[INFO]: On GPU {self.tp_rank}, Custom Forward pass peak memory: {custom_max_memory:.2f} MB" + ) + + g_logprobs = torch.empty_like(custom_logprobs).uniform_(-0.1, 0.1) + dist.broadcast(g_logprobs, src=0, group=self.tp_group) + + torch.cuda.reset_peak_memory_stats() + (d_hidden_custom, d_weight_custom) = torch.autograd.grad( + (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + torch.cuda.synchronize() + custom_max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 + if 
self.is_chief: + print( + f"[INFO]: On GPU {self.tp_rank}, Custom Backward pass peak memory: {custom_max_memory:.2f} MB" + ) + + self.cleanup() + torch_storage() + self.cleanup() + custom_storage() + + +@pytest.mark.skipif( + "WORLD_SIZE" not in os.environ or int(os.environ["WORLD_SIZE"]) < 2, + reason="Requires torchrun with multiple GPUs", +) +class TestFusedLinearCrossEntropySequenceParallel: + @classmethod + def setup_class(cls): + if dist.is_initialized(): + cls.must_teardown = False + else: + dist.init_process_group( + backend="nccl", + init_method="env://", + world_size=int(os.environ["WORLD_SIZE"]), + rank=int(os.environ["RANK"]), + ) + cls.must_teardown = True + cls.tp_group = dist.group.WORLD + + cls.tp_rank = dist.get_rank(cls.tp_group) + cls.tp_world_size = dist.get_world_size(cls.tp_group) + cls.is_chief = cls.tp_rank == 0 + device = torch.device(f"cuda:{cls.tp_rank}") + torch.cuda.set_device(device) + print(f"[INFO]: TP rank: {cls.tp_rank}, TP world size: {cls.tp_world_size}") + + @classmethod + def teardown_class(cls): + if cls.must_teardown: + dist.destroy_process_group() + + @staticmethod + def timed_barrier(timeout_s=10): + import time + + work = torch.distributed.barrier(async_op=True) + t0 = time.monotonic()  # monotonic clock: unaffected by wall-clock adjustments + while not work.is_completed(): + if time.monotonic() - t0 > timeout_s: + raise SystemExit(1)  # same exit status as before, without relying on the site-provided exit() + time.sleep(0.05) + work.wait() + + def cleanup(self): + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + import gc + + gc.collect() + torch.cuda.synchronize() + + @staticmethod + def torch_linear_cross_entropy_single_gpu( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: typing.Optional[str] = "mean", + ): + logits = hidden.to(torch.float32) @ weight.T.to(torch.float32) + logprobs = torch.nn.functional.cross_entropy( + logits.view(-1, logits.shape[-1]), labels.view(-1), reduction=reduction + ) + return logprobs.to(torch.float32) + + class TorchLinearCrossEntropy(torch.autograd.Function): + @staticmethod + def 
forward( + ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + tp_group: torch.distributed.ProcessGroup, + reduction: typing.Optional[str] = "mean", + ): + tp_rank = 0 if tp_group is None else torch.distributed.get_rank(tp_group) + tp_world_size = 1 if tp_group is None else torch.distributed.get_world_size(tp_group) + + whole_hidden = torch.empty( + (hidden.shape[0] * tp_world_size, hidden.shape[-1]), + dtype=hidden.dtype, + device=hidden.device, + ) + dist.all_gather_into_tensor(whole_hidden, hidden, group=tp_group) + + logits = whole_hidden.to(torch.float32) @ weight.T.to(torch.float32) + + whole_logits = torch.empty( + (logits.shape[0], logits.shape[-1] * tp_world_size), + dtype=logits.dtype, + device=logits.device, + ) + whole_logits_ref = [ + whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]] + for i in range(tp_world_size) + ] + dist.all_gather(whole_logits_ref, logits, group=tp_group) + + logprobs = torch.nn.functional.cross_entropy( + whole_logits.view(-1, whole_logits.shape[-1]), labels.view(-1), reduction=reduction + ) + + # If we don't preserve whole_logits, + # we need to re-compute it in the backward pass + ctx.save_for_backward(whole_hidden, weight, labels) + ctx.tp_group = tp_group + ctx.reduction = reduction + ctx.tp_rank = tp_rank + ctx.tp_world_size = tp_world_size + + return logprobs.to(torch.float32) + + @staticmethod + def backward(ctx, g_logprobs: torch.Tensor): + whole_hidden, weight, labels = ctx.saved_tensors + tp_group = ctx.tp_group + reduction = ctx.reduction + tp_rank = ctx.tp_rank + tp_world_size = ctx.tp_world_size + + num_tokens, dim = whole_hidden.shape + + if reduction == "mean": + _g_logprobs = torch.broadcast_to(g_logprobs / num_tokens, (num_tokens,)) + elif reduction == "sum": + _g_logprobs = torch.broadcast_to(g_logprobs, (num_tokens,)) + else: + _g_logprobs = g_logprobs + + # re-compute whole_logits + logits = whole_hidden.to(torch.float32) @ weight.T.to(torch.float32) + 
whole_logits = torch.empty( + (logits.shape[0], logits.shape[-1] * tp_world_size), + dtype=logits.dtype, + device=logits.device, + ) + whole_logits_ref = [ + whole_logits[..., i * logits.shape[-1] : (i + 1) * logits.shape[-1]] + for i in range(tp_world_size) + ] + dist.all_gather(whole_logits_ref, logits, group=tp_group) + + one_hot = torch.zeros_like(whole_logits) + one_hot.scatter_(1, labels.view(-1).unsqueeze(-1), 1) + + pd = torch.nn.functional.softmax(whole_logits, dim=-1) + d_logits = (pd - one_hot) * _g_logprobs.unsqueeze(-1) + d_logits = d_logits.to(whole_hidden.dtype) + + local_size = weight.size(0) + local_d_logits = d_logits[:, tp_rank * local_size : (tp_rank + 1) * local_size] + + d_hidden = local_d_logits @ weight + local_d_weight = local_d_logits.T @ whole_hidden + + # dist.all_reduce( + # local_d_hidden, + # op=dist.ReduceOp.SUM, + # group=tp_group + # ) + + # split the local_d_hidden along the sequence length dimension + local_num_tokens = num_tokens // tp_world_size + # local_d_hidden = local_d_hidden[tp_rank * local_num_tokens : (tp_rank + 1) * local_num_tokens, :] + + local_d_hidden = torch.empty( + (local_num_tokens, dim), dtype=weight.dtype, device=weight.device + ) + dist.reduce_scatter_tensor( + local_d_hidden, d_hidden, op=dist.ReduceOp.SUM, group=tp_group + ) + return local_d_hidden, local_d_weight, None, None, None + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16]) + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + @pytest.mark.parametrize("problem", [(256, 12928, 8192)]) + def test_torch_tp_vs_single_gpu(self, dtype, reduction, problem): + num_tokens, vocabsize, dim = problem + + hidden = ( + torch.empty((num_tokens, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + weight = ( + torch.empty((vocabsize, dim), dtype=dtype, device="cuda") + .uniform_(-0.1, 0.1) + .requires_grad_() + ) + labels = torch.randint( + 0, vocabsize, (num_tokens * self.tp_world_size,), 
dtype=torch.long, device="cuda" + ) + + # ------------ forward pass ------------ # + dist.broadcast(labels, src=0, group=self.tp_group) + + # single GPU + whole_hidden = torch.empty( + (num_tokens * self.tp_world_size, dim), dtype=dtype, device="cuda" + ) + dist.all_gather_into_tensor(whole_hidden, hidden, group=self.tp_group) + whole_hidden = whole_hidden.clone().requires_grad_() + + whole_weight = torch.empty( + (vocabsize * self.tp_world_size, dim), dtype=dtype, device="cuda" + ) + whole_weight_view = [ + whole_weight[i * vocabsize : (i + 1) * vocabsize, :] for i in range(self.tp_world_size) + ] + dist.all_gather(whole_weight_view, weight, group=self.tp_group) + whole_weight = whole_weight.clone().requires_grad_() + logprobs_single_gpu = self.torch_linear_cross_entropy_single_gpu( + whole_hidden, whole_weight, labels, reduction=reduction + ) + + # TP + logprobs_tp = self.TorchLinearCrossEntropy.apply( + hidden, weight, labels, self.tp_group, reduction + ) + torch.testing.assert_close(logprobs_single_gpu, logprobs_tp) + + # ------------ backward pass ------------ # + g_logprobs = torch.empty_like(logprobs_single_gpu).uniform_(-0.1, 0.1) + dist.broadcast(g_logprobs, src=0, group=self.tp_group) + + # single GPU + (d_hidden_single_gpu, d_weight_single_gpu) = torch.autograd.grad( + (logprobs_single_gpu,), (whole_hidden, whole_weight), (g_logprobs,), retain_graph=False + ) + + # TP + (d_hidden_tp, d_weight_tp) = torch.autograd.grad( + (logprobs_tp,), (hidden, weight), (g_logprobs,), retain_graph=False + ) + + local_d_hidden_single_gpu = d_hidden_single_gpu[ + self.tp_rank * hidden.shape[0] : (self.tp_rank + 1) * hidden.shape[0], : + ] + torch.testing.assert_close(local_d_hidden_single_gpu, d_hidden_tp, atol=1e-3, rtol=1e-3) + local_d_weight_single_gpu = d_weight_single_gpu[ + self.tp_rank * weight.shape[0] : (self.tp_rank + 1) * weight.shape[0], : + ] + torch.testing.assert_close(local_d_weight_single_gpu, d_weight_tp, atol=1e-3, rtol=1e-3) + + self.cleanup() + + 
# NOTE: the definitions down to ``test_storage`` are methods of the sequence-parallel
# test class whose header precedes this chunk.  They read ``self.tp_group``,
# ``self.tp_rank``, ``self.tp_world_size``, ``self.is_chief`` and call
# ``self.TorchLinearCrossEntropy``, ``self.timed_barrier`` and ``self.cleanup``.
@staticmethod
def get_problems():
    """Return the ``(num_tokens, vocabsize, dim)`` cases for ``test_correctness``.

    ``num_tokens`` is either an int or a tuple of leading dimensions of the
    per-rank ``hidden`` tensor; ``vocabsize`` is the per-rank weight row count.
    """
    return [
        (80, 125, 64),
        (80, 152064, 64),
        (1024, 152064, 4096),
        (4096, 15206, 1024),
        ((1, 4096), 15206, 1024),
        ((4, 1024), 15206, 1024),
    ]

def _make_sp_inputs(self, problem, dtype, global_labels=False):
    """Create per-rank ``hidden``/``weight`` and the label tensor for one problem.

    Args:
        problem: ``(num_tokens, vocabsize, dim)`` as produced by ``get_problems``.
        dtype: dtype of ``hidden`` and ``weight``.
        global_labels: when True, labels are drawn from
            ``[0, vocabsize * tp_world_size)`` so that targets fall into every
            rank's weight shard; otherwise from ``[0, vocabsize)``.

    Returns:
        ``(hidden, weight, labels)``.  The leading label dimension is
        ``tp_world_size`` times the leading ``hidden`` dimension.
    """
    num_tokens, vocabsize, dim = problem
    if isinstance(num_tokens, int):
        hidden_shape = (num_tokens, dim)
        labels_shape = (num_tokens * self.tp_world_size,)
    else:
        hidden_shape = (*num_tokens, dim)
        labels_shape = (num_tokens[0] * self.tp_world_size, *num_tokens[1:])

    hidden = (
        torch.empty(hidden_shape, dtype=dtype, device="cuda")
        .uniform_(-0.1, 0.1)
        .requires_grad_()
    )
    weight = (
        torch.empty((vocabsize, dim), dtype=dtype, device="cuda")
        .uniform_(-0.1, 0.1)
        .requires_grad_()
    )
    label_range = vocabsize * self.tp_world_size if global_labels else vocabsize
    labels = torch.randint(0, label_range, labels_shape, dtype=torch.long, device="cuda")
    return hidden, weight, labels

@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
@pytest.mark.parametrize("problem", get_problems())
def test_correctness(self, dtype, reduction, problem):
    """Compare the fused kernel with the torch reference (forward and backward)."""
    dim = problem[2]
    # Labels index the full vocabulary (all ranks' shards concatenated), so sample
    # from the global range; sampling only [0, vocabsize) would place every target
    # in rank 0's shard and leave the other ranks' label handling untested.
    hidden, weight, labels = self._make_sp_inputs(problem, dtype, global_labels=True)

    # ------ forward pass ------ #
    dist.broadcast(labels, src=0, group=self.tp_group)

    torch_logprobs = self.TorchLinearCrossEntropy.apply(
        hidden.view(-1, dim), weight, labels, self.tp_group, reduction
    )

    custom_logprobs = linear_cross_entropy(
        hidden,
        weight,
        labels,
        tp_group=self.tp_group,
        reduction=reduction,
        sequence_parallel=True,
    )

    torch.testing.assert_close(torch_logprobs, custom_logprobs)

    # ------- backward pass ------- #
    g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1)
    dist.broadcast(g_logprobs, src=0, group=self.tp_group)

    (d_hidden_torch, d_weight_torch) = torch.autograd.grad(
        (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False
    )
    (d_hidden_custom, d_weight_custom) = torch.autograd.grad(
        (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False
    )

    torch.testing.assert_close(d_hidden_torch, d_hidden_custom, atol=1e-3, rtol=1e-3)
    torch.testing.assert_close(d_weight_torch, d_weight_custom, atol=1e-3, rtol=1e-3)
    # Guards against a hang when one GPU fails while the others keep going
    # (presumably the barrier times out instead of blocking forever -- it is
    # defined outside this chunk).
    self.timed_barrier()

    self.cleanup()

@pytest.mark.parametrize("problem", [((1, 1024), 129280, 7168)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("reduction", ["mean"])
def test_performance(self, problem, dtype, reduction):
    """Print mean forward/backward latency of the torch reference and the fused kernel."""
    dim = problem[2]

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    def timed(fn):
        """Run ``fn`` between two CUDA events; return ``(result, elapsed_ms)``."""
        start_event.record()
        result = fn()
        end_event.record()
        torch.cuda.synchronize()
        return result, start_event.elapsed_time(end_event)

    torch_fwd_latency = []
    torch_bwd_latency = []
    custom_fwd_latency = []
    custom_bwd_latency = []

    iterations = 5
    for _ in range(iterations):
        hidden, weight, labels = self._make_sp_inputs(problem, dtype)

        # ------ forward pass ------ #
        dist.broadcast(labels, src=0, group=self.tp_group)

        torch_logprobs, elapsed = timed(
            lambda: self.TorchLinearCrossEntropy.apply(
                hidden.view(-1, dim), weight, labels, self.tp_group, reduction
            )
        )
        torch_fwd_latency.append(elapsed)

        custom_logprobs, elapsed = timed(
            lambda: linear_cross_entropy(
                hidden,
                weight,
                labels,
                tp_group=self.tp_group,
                reduction=reduction,
                sequence_parallel=True,
            )
        )
        custom_fwd_latency.append(elapsed)

        # ------- backward pass ------- #
        g_logprobs = torch.empty_like(torch_logprobs).uniform_(-0.1, 0.1)
        dist.broadcast(g_logprobs, src=0, group=self.tp_group)

        _, elapsed = timed(
            lambda: torch.autograd.grad(
                (torch_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False
            )
        )
        torch_bwd_latency.append(elapsed)

        _, elapsed = timed(
            lambda: torch.autograd.grad(
                (custom_logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False
            )
        )
        custom_bwd_latency.append(elapsed)

    # --- remove first latency due to warmup --- #
    torch_fwd_latency = torch_fwd_latency[1:]
    torch_bwd_latency = torch_bwd_latency[1:]
    custom_fwd_latency = custom_fwd_latency[1:]
    custom_bwd_latency = custom_bwd_latency[1:]

    if self.is_chief:
        print()
        print(
            f"[INFO]: On problem {problem}, dtype {dtype}, reduction {reduction}, TP size {self.tp_world_size}, Sequence Parallel: True:"
        )
        print(
            f"[INFO]: Torch forward latency: {sum(torch_fwd_latency) / len(torch_fwd_latency):.2f} ms"
        )
        print(
            f"[INFO]: Custom forward latency: {sum(custom_fwd_latency) / len(custom_fwd_latency):.2f} ms"
        )
        print(
            f"[INFO]: Torch backward latency: {sum(torch_bwd_latency) / len(torch_bwd_latency):.2f} ms"
        )
        print(
            f"[INFO]: Custom backward latency: {sum(custom_bwd_latency) / len(custom_bwd_latency):.2f} ms"
        )

@pytest.mark.parametrize("problem", [((1, 1024), 129280, 7168)])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("reduction", ["mean"])
def test_storage(self, problem, dtype, reduction):
    """Print peak CUDA memory of forward/backward for the reference and the fused kernel."""
    dim = problem[2]

    if self.is_chief:
        print()
        print(
            f"[INFO]: On problem {problem}, dtype {dtype}, reduction {reduction}, TP size {self.tp_world_size}, Sequence Parallel: True:"
        )

    def peak_mb(fn):
        """Run ``fn`` after resetting peak stats; return ``(result, peak_MiB)``."""
        torch.cuda.reset_peak_memory_stats()
        result = fn()
        torch.cuda.synchronize()
        return result, torch.cuda.max_memory_allocated() / 1024 / 1024

    def report(name, phase, peak):
        if self.is_chief:
            print(
                f"[INFO]: On GPU {self.tp_rank}, {name} {phase} pass peak memory: {peak:.2f} MB"
            )

    def measure(name, forward):
        # Inputs live only inside this function so they are released before the
        # next implementation is measured.
        hidden, weight, labels = self._make_sp_inputs(problem, dtype)

        dist.broadcast(hidden, src=0, group=self.tp_group)
        dist.broadcast(labels, src=0, group=self.tp_group)

        logprobs, peak = peak_mb(lambda: forward(hidden, weight, labels))
        report(name, "Forward", peak)

        g_logprobs = torch.empty_like(logprobs).uniform_(-0.1, 0.1)
        dist.broadcast(g_logprobs, src=0, group=self.tp_group)

        _, peak = peak_mb(
            lambda: torch.autograd.grad(
                (logprobs,), (hidden, weight), (g_logprobs,), retain_graph=False
            )
        )
        report(name, "Backward", peak)

    def torch_forward(hidden, weight, labels):
        return self.TorchLinearCrossEntropy.apply(
            hidden.view(-1, dim), weight, labels, self.tp_group, reduction
        )

    def custom_forward(hidden, weight, labels):
        return linear_cross_entropy(
            hidden,
            weight,
            labels,
            tp_group=self.tp_group,
            reduction=reduction,
            sequence_parallel=True,
        )

    self.cleanup()
    measure("Torch", torch_forward)
    self.cleanup()
    measure("Custom", custom_forward)

# diff --git a/transformer_engine/common/cutedsl/__init__.py b/transformer_engine/common/cutedsl/__init__.py new file mode 100644 index 0000000000..0a35d614ce --- /dev/null +++ b/transformer_engine/common/cutedsl/__init__.py @@ -0,0 +1 @@
"""Kernels written with CUTLASS DSL."""
# \ No newline at end of file diff --git a/transformer_engine/common/cutedsl/linear_cross_entropy/__init__.py b/transformer_engine/common/cutedsl/linear_cross_entropy/__init__.py new file mode 100644 index 0000000000..38711fe3d5 --- /dev/null +++ b/transformer_engine/common/cutedsl/linear_cross_entropy/__init__.py @@ -0,0 +1,3 @@
from transformer_engine.common.cutedsl.linear_cross_entropy import blackwell

__all__ = ["blackwell"]
# \ No newline at end of file diff --git a/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/__init__.py b/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/__init__.py new file mode 100644 index
# 0000000000..be5aedef40 --- /dev/null +++ b/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/__init__.py @@ -0,0 +1,4 @@
from .bwd_partial_dlogits import BwdPartialDlogits
from .fwd_mainloop import FwdMainLoop

__all__ = ["BwdPartialDlogits", "FwdMainLoop"]
# \ No newline at end of file diff --git a/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/bwd_partial_dlogits.py b/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/bwd_partial_dlogits.py new file mode 100644 index 0000000000..3621e23c58 --- /dev/null +++ b/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/bwd_partial_dlogits.py @@ -0,0 +1,637 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

from typing import Optional, Tuple, Type

import cuda.bindings.driver as cuda  # type: ignore
import cutlass
import cutlass.cute as cute
import cutlass.pipeline as pipeline  # type: ignore
import cutlass.utils as utils  # type: ignore
import cutlass.utils.blackwell_helpers as sm100_utils  # type: ignore
from cutlass.cute.nvgpu import cpasync, tcgen05

# Upper bound checked against the TMEM columns requested in `_setup_attributes`.
SM100_TMEM_CAPACITY_COLUMNS: int = 512


def make_thread_cooperative_group(size: int, alignment: Optional[int] = None):
    """
    Create a thread cooperative group.

    :param size: number of participants in the group.
    :param alignment: forwarded to ``pipeline.CooperativeGroup``; defaults to ``size``.
    """
    return pipeline.CooperativeGroup(
        pipeline.Agent.Thread, size, alignment=alignment if alignment is not None else size
    )


class BwdPartialDlogits:
    """
    This class implements the backward kernel for partial d_logits.

    For one vocabulary split (``split_idx``) the kernel multiplies ``hidden`` by the
    matching rows of ``weight`` to recompute the logits tile, turns it into
    ``d_logits`` in the epilogue and stores the result into ``dlogits_partial``.
    """

    def __init__(
        self,
        reduction: int,
        acc_dtype: Type[cutlass.Numeric] = cutlass.Float32,
        use_2cta_instrs: bool = False,
        mma_tiler_mn: Tuple[int, int] = (128, 256),
        vocab_per_split: int = 512,
    ):
        """
        Store the compile-time configuration.

        :param reduction: reduction mode as tested in ``kernel``: 2 is mean, 1 is sum,
            any other value is treated as no reduction.
        :param acc_dtype: accumulator dtype of the MMA.
        :param use_2cta_instrs: selects ``CtaGroup.TWO`` and a (2, 1) cluster.
        :param mma_tiler_mn: (M, N) extent of the MMA tile; K is filled in later by
            ``_setup_attributes``.
        :param vocab_per_split: number of weight rows handled per launch.
        """
        self.REDUCTION: cutlass.Constexpr[cutlass.Int32] = cutlass.const_expr(reduction)
        self.acc_dtype = acc_dtype
        self.use_2cta_instrs = use_2cta_instrs
        # K extent is a placeholder here; overwritten in `_setup_attributes`.
        self.mma_tiler = (*mma_tiler_mn, 1)
        self.vocab_per_split = vocab_per_split

        self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE
        self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1)

        self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100")

        self.threads_per_warp: int = 32

        # Warp specialization: warps 0-3 run the epilogue, warp 4 issues TMA loads,
        # warp 5 issues MMAs, warps 6-7 only allocate/free TMEM (first one) or idle.
        self.epi_warp_ids = (0, 1, 2, 3)
        self.load_warp_ids = 4
        self.mma_warp_ids = 5
        self.empty_warp_ids = (6, 7)

        # 8 warps in total.
        self.threads_per_cta: int = self.threads_per_warp * len(
            (*self.epi_warp_ids, self.load_warp_ids, self.mma_warp_ids, *self.empty_warp_ids)
        )
        # Barrier over every thread of the CTA, used around TMEM alloc/dealloc.
        self.cta_sync_barrier = pipeline.NamedBarrier(
            barrier_id=1, num_threads=self.threads_per_cta
        )

        self.buffer_align_bytes: int = 1024
        # Register budgets passed to warpgroup_reg_dealloc / warpgroup_reg_alloc.
        self.num_regs_other: int = 32
        self.num_regs_epi: int = 192

    def _compute_grid(
        self,
        problem_mnk: Tuple[int, int, int],
        cluster_shape_mn: Tuple[int, int],
        cta_tiler: Tuple[int, int, int],
    ) -> Tuple[int, int, int]:
        """
        Return the launch grid: one CTA per (M tile, N tile of one vocab split),
        rounded up to a multiple of the cluster shape.

        Only ``problem_mnk[0]`` is used; the N extent comes from ``vocab_per_split``.
        """
        cluster_shape_mnk = (*cluster_shape_mn, 1)

        grid = cute.round_up(
            (
                cute.ceil_div(problem_mnk[0], cta_tiler[0]),
                cute.ceil_div(self.vocab_per_split, cta_tiler[1]),
                1,
            ),
            cluster_shape_mnk,
        )
        return grid

    def _compute_stages(
        self,
        tiled_mma: cute.TiledMma,
        mma_tiler: Tuple[int, int, int],
        a_dtype: Type[cutlass.Numeric],
        b_dtype: Type[cutlass.Numeric],
    ):
        """
        Return ``(num_acc_stage, num_ab_stage, num_epi_stage_per_tile)``.

        The values are fixed constants; the arguments are currently unused.
        """
        num_acc_stage = 1
        num_ab_stage = 4
        num_epi_stage_per_tile = 4
        return num_acc_stage, num_ab_stage, num_epi_stage_per_tile

    def _setup_attributes(
        self,
        tiled_mma: cute.TiledMma,
        a_dtype: Type[cutlass.Numeric],
        b_dtype: Type[cutlass.Numeric],
    ):
        """
        Derive cluster layout, final MMA tile, stage counts, TMEM column count and
        the per-CTA tile shape from ``tiled_mma``; results are stored on ``self``.
        """
        self.cluster_shape_mnk = (*self.cluster_shape_mn, 1)
        self.cluster_layout_vmnk = cute.tiled_divide(
            cute.make_layout(self.cluster_shape_mnk), (tiled_mma.thr_id.shape,)
        )

        mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2])
        # it requires k-mode to be 128B aligned
        mma_inst_tile_k: int = 4
        self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k)

        self.num_acc_stage, self.num_ab_stage, self.num_epi_stage_per_tile = self._compute_stages(
            tiled_mma, self.mma_tiler, a_dtype, b_dtype
        )
        # One accumulator of tile-N columns per accumulator stage.
        self.tmem_alloc_cols = self.num_acc_stage * self.mma_tiler[1]
        assert self.tmem_alloc_cols <= SM100_TMEM_CAPACITY_COLUMNS

        # M is divided by the number of CTAs cooperating on one MMA.
        self.cta_tile_shape_mnk = (
            self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape),
            self.mma_tiler[1],
            self.mma_tiler[2],
        )

    @cute.kernel
    def kernel(
        self,
        split_idx: cutlass.Int32,
        tiled_mma: cute.TiledMma,
        tma_atom_a: cute.CopyAtom,
        mA: cute.Tensor,
        tma_atom_b: cute.CopyAtom,
        mB: cute.Tensor,
        mLabels: cute.Tensor,
        mDlogprobs: cute.Tensor,
        mMaximum: cute.Tensor,
        mAccu: cute.Tensor,
        mDlogits_partial: cute.Tensor,
        scalarNumValidTokens: cute.Pointer,
        ignore_index: cutlass.Int64,
        a_smem_layout_staged: cute.ComposedLayout,
        b_smem_layout_staged: cute.ComposedLayout,
        cluster_layout_vmnk: cute.Layout,
        problem_mnk: Tuple[int, int, int],
        rank: cutlass.Int32,
    ) -> None:
        """
        The backward kernel for partial d_logits.

        Warp roles: the load warp streams A/B K-tiles into shared memory through the
        ``ab_pipeline``; the MMA warp accumulates them into TMEM; the four epilogue
        warps read the accumulator and, per element, compute

            exp(logit - maximum) * dlogprob / accu  -  (position == label) * dlogprob

        where ``dlogprob`` is zeroed for rows whose label equals ``ignore_index`` and
        ``position`` is the global vocabulary index
        ``rank * problem_mnk[1] + split_idx * vocab_per_split + ...``.  The result is
        converted to the element type of ``mDlogits_partial`` and stored with a
        bounds predicate.
        """
        warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx())
        tidx, _, _ = cute.arch.thread_idx()
        bidx, bidy, _ = cute.arch.block_idx()
        # FIXME: block swizzling applied here
        pidm, pidn = bidx, bidy

        # FIXME: if 2 CTAs, modify here
        cta_rank_in_cluster = 0
        block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster)

        # prefetch tma descriptors
        if warp_idx == self.load_warp_ids:
            cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a)
            cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b)

        smem = utils.SmemAllocator()
        storage = smem.allocate(self.shared_storage)

        # Load warp (producer) -> MMA warp (consumer); one participant on each side.
        ab_pipeline = pipeline.PipelineTmaUmma.create(
            num_stages=self.num_ab_stage,
            producer_group=make_thread_cooperative_group(len([self.load_warp_ids])),
            consumer_group=make_thread_cooperative_group(len([self.mma_warp_ids])),
            tx_count=self.tma_copy_ab_bytes,
            barrier_storage=storage.load_ab_mbar_ptr.data_ptr(),
        )
        ab_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.num_ab_stage
        )
        ab_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.num_ab_stage
        )

        # MMA warp (producer) -> all epilogue threads (consumers).
        mma_pipeline = pipeline.PipelineUmmaAsync.create(
            num_stages=self.num_acc_stage,
            producer_group=make_thread_cooperative_group(len([self.mma_warp_ids])),
            consumer_group=make_thread_cooperative_group(
                self.threads_per_warp * len(self.epi_warp_ids)
            ),
            barrier_storage=storage.mma_mbar_ptr.data_ptr(),
        )
        mma_producer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Producer, self.num_acc_stage
        )
        mma_consumer_state = pipeline.make_pipeline_state(
            pipeline.PipelineUserType.Consumer, self.num_acc_stage
        )

        # NOTE(review): this mbarrier is initialized here but not waited on or
        # arrived at anywhere else in this kernel -- confirm whether it is needed.
        tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr()
        if warp_idx == self.empty_warp_ids[0]:
            with cute.arch.elect_one():
                cute.arch.mbarrier_init(
                    tmem_dealloc_mbar_ptr, self.threads_per_warp * len(self.epi_warp_ids)
                )
        cute.arch.mbarrier_init_fence()

        # -------- tensor partition ------------ #
        # swizzle o [(tileM, tileK), loopM, loopK, stage]
        sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner)
        # swizzle o [(tileN, tileK), loopN, loopK, stage]
        sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner)

        # FIXME: if 2 CTAs, modify here
        thr_mma = tiled_mma.get_slice(0)
        # [MMA, loopM, loopK, stage]
        tCsA = thr_mma.make_fragment_A(sA)
        # [MMA, loopN, loopK, stage]
        tCsB = thr_mma.make_fragment_B(sB)

        # [tileM, tileK, loopK]
        gA = cute.local_tile(
            mA, (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]), (pidm, None)
        )
        # [vocab_per_split, dim]
        mB_n = cute.local_tile(
            mB, (self.vocab_per_split, cute.size(mB.layout.shape, mode=[1])), (split_idx, 0)
        )
        # [tileN, tileK, loopK]
        gB = cute.local_tile(
            mB_n, (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]), (pidn, None)
        )

        a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape)
        # just to make sure SMEM and GMEM tensor has the same size in the first rank
        tCgA = thr_mma.partition_A(gA)
        tCgB = thr_mma.partition_B(gB)
        # [CPY, stage] & [CPY, loopK]
        tTMAsA, tTMAgA = cpasync.tma_partition(
            tma_atom_a,
            block_in_cluster_coord_vmnk[2],  # cta_coord,
            a_cta_layout,
            cute.group_modes(sA, 0, 3),
            cute.group_modes(tCgA, 0, 3),
        )
        b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape)
        # [CPY, stage] & [CPY, loopK]
        tTMAsB, tTMAgB = cpasync.tma_partition(
            tma_atom_b,
            block_in_cluster_coord_vmnk[1],  # cta_coord
            b_cta_layout,
            cute.group_modes(sB, 0, 3),
            cute.group_modes(tCgB, 0, 3),
        )

        # ------ Allocate TMEM ------ #
        tmem_holding_buf = storage.tmem_holding_buf
        if warp_idx == self.empty_warp_ids[0]:
            cute.arch.alloc_tmem(
                self.tmem_alloc_cols, tmem_holding_buf, is_two_cta=self.use_2cta_instrs
            )
        # Every thread waits here before reading the TMEM address from shared memory.
        self.cta_sync_barrier.arrive_and_wait()
        tmem_ptr = cute.arch.retrieve_tmem_ptr(
            self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf
        )

        tmem_shape = (128, self.tmem_alloc_cols)
        acc_shape = thr_mma.partition_shape_C(tmem_shape)
        tCtC_fake = thr_mma.make_fragment_C(acc_shape)
        # [(tileM, tileN), loopM, loopN]
        tCtC = cute.make_tensor(tmem_ptr, tCtC_fake.layout)

        # ------ Empty ------ #
        if warp_idx in self.empty_warp_ids:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_other)

        # ------ Load ------ #
        if warp_idx == self.load_warp_ids:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_other)

            # One A tile and one B tile per K step, both signalling the same barrier.
            for k in cutlass.range(cute.size(gA, mode=[2])):
                ab_pipeline.producer_acquire(ab_producer_state)
                cute.copy(
                    tma_atom_a,
                    tTMAgA[(None, k)],
                    tTMAsA[(None, ab_producer_state.index)],
                    tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
                )
                cute.copy(
                    tma_atom_b,
                    tTMAgB[(None, k)],
                    tTMAsB[(None, ab_producer_state.index)],
                    tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state),
                )
                ab_pipeline.producer_commit(ab_producer_state)
                ab_producer_state.advance()

        # ------ MMA ------ #
        if warp_idx == self.mma_warp_ids:
            cute.arch.warpgroup_reg_dealloc(self.num_regs_other)

            # The first MMA overwrites the accumulator; later ones accumulate.
            tiled_mma.set(tcgen05.Field.ACCUMULATE, False)
            mma_pipeline.producer_acquire(mma_producer_state)

            for k in cutlass.range(cute.size(gA, mode=[2])):
                ab_pipeline.consumer_wait(ab_consumer_state)

                for kblock_idx in cutlass.range(cute.size(tCsA, mode=[2]), unroll_full=True):
                    cute.gemm(
                        tiled_mma,
                        cute.append_ones(tCtC[(None, None, mma_producer_state.index)]),
                        tCsA[(None, None, kblock_idx, ab_consumer_state.index)],
                        tCsB[(None, None, kblock_idx, ab_consumer_state.index)],
                        cute.append_ones(tCtC[(None, None, mma_producer_state.index)]),
                    )
                    tiled_mma.set(tcgen05.Field.ACCUMULATE, True)

                ab_pipeline.consumer_release(ab_consumer_state)
                ab_consumer_state.advance()

            mma_pipeline.producer_commit(mma_producer_state)
            mma_producer_state.advance()

        # ------ EPI ------ #
        if warp_idx in self.epi_warp_ids:
            cute.arch.warpgroup_reg_alloc(self.num_regs_epi)

            # TMEM -> register copy, one N sub-tile (tileN / num_epi_stage_per_tile) at a time.
            copy_atom_t2r = sm100_utils.get_tmem_load_op(
                self.cta_tile_shape_mnk,
                utils.LayoutEnum.ROW_MAJOR,
                self.acc_dtype,
                self.acc_dtype,
                (self.epi_tile[0], self.epi_tile[1] // self.num_epi_stage_per_tile),
                self.use_2cta_instrs,
            )
            # [tileM, subTileN, loopM, CntSubTileN, loopN]
            tAcc_epi = cute.flat_divide(
                tCtC[((None, None), 0, None)],
                (self.epi_tile[0], self.epi_tile[1] // self.num_epi_stage_per_tile),
            )
            tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)])
            thr_copy_t2r = tiled_copy_t2r.get_slice(tidx)
            tTMEM_load_tAcc = thr_copy_t2r.partition_S(tAcc_epi)
            tTMEM_load_tAcc = cute.group_modes(tTMEM_load_tAcc, 3, cute.rank(tTMEM_load_tAcc) - 1)

            # predicates
            cAcc = cute.make_identity_tensor(self.mma_tiler[:2])
            tCcAcc = thr_mma.partition_C(cAcc)
            tCcAcc_epi = cute.flat_divide(
                tCcAcc[((None, None), 0, None)],
                (self.epi_tile[0], self.epi_tile[1] // self.num_epi_stage_per_tile),
            )
            tTMEM_load_cAcc = thr_copy_t2r.partition_D(tCcAcc_epi)
            tTMEM_load_cAcc_shape = cute.select(tTMEM_load_cAcc.shape, mode=[0, 1, 2])
            tTMEM_load_rAcc = cute.make_fragment(tTMEM_load_cAcc_shape, self.acc_dtype)

            # Per-row scalar loads; the copy atoms take their element types from
            # mLabels and mDlogprobs respectively.
            copy_atom_g2r_int64 = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), mLabels.element_type
            )
            copy_atom_g2r_fp32 = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), mDlogprobs.element_type
            )
            # 128 epilogue threads, one per row of the tile.
            epilogue_thread_layout = cute.make_layout((128, 1), stride=(1, 1))
            tiled_copy_g2r_int64 = cute.make_tiled_copy_tv(
                copy_atom_g2r_int64, epilogue_thread_layout, cute.make_layout((1, 1))
            )
            tiled_copy_g2r_fp32 = cute.make_tiled_copy_tv(
                copy_atom_g2r_fp32, epilogue_thread_layout, cute.make_layout((1, 1))
            )
            thr_copy_g2r_int64 = tiled_copy_g2r_int64.get_slice(tidx)
            thr_copy_g2r_fp32 = tiled_copy_g2r_fp32.get_slice(tidx)

            # [tileM]
            gLabels = cute.local_tile(mLabels, (self.epi_tile[0],), (pidm,))
            gMaximum = cute.local_tile(mMaximum, (self.epi_tile[0],), (pidm,))
            gAccu = cute.local_tile(mAccu, (self.epi_tile[0],), (pidm,))

            # slice along M direction
            tMCAcc = thr_copy_g2r_int64.partition_S(cAcc)[(None, None, 0)]
            # [(1, 1), 1]
            tMCAcc_mask = cute.make_fragment(tMCAcc.shape, cutlass.Boolean)
            # to align shape with gMax and gAccu
            tMCAcc_mask = cute.append_ones(tMCAcc_mask)
            # True when this thread's row lies inside the M extent of A.
            tMCAcc_mask[0] = cute.elem_less(pidm * self.epi_tile[0] + tidx, cute.size(mA, mode=[0]))
            # [(1, 1), 1, 1]
            tMgLabels = thr_copy_g2r_int64.partition_S(cute.append_ones(gLabels))
            tMrLabels = cute.make_fragment(tMgLabels.shape, tMgLabels.element_type)
            cute.copy(tiled_copy_g2r_int64, tMgLabels, tMrLabels, pred=tMCAcc_mask)
            tMgMaximum = thr_copy_g2r_fp32.partition_S(cute.append_ones(gMaximum))
            tMrMaximum = cute.make_fragment(tMgMaximum.layout, tMgMaximum.element_type)
            cute.copy(tiled_copy_g2r_fp32, tMgMaximum, tMrMaximum, pred=tMCAcc_mask)
            tMgAccu = thr_copy_g2r_fp32.partition_S(cute.append_ones(gAccu))
            tMrAccu = cute.make_fragment(tMgAccu.layout, tMgAccu.element_type)
            cute.copy(tiled_copy_g2r_fp32, tMgAccu, tMrAccu, pred=tMCAcc_mask)

            tMrDlogprobs = cute.make_fragment(tMgAccu.layout, mDlogprobs.element_type)
            if cutlass.const_expr(self.REDUCTION == 2):
                # mean reduction
                num_valid_tokens = cute.make_tensor(scalarNumValidTokens, layout=(1,))
                tMrDlogprobs[0] = mDlogprobs[0] / num_valid_tokens[0].to(cutlass.Float32)
            elif cutlass.const_expr(self.REDUCTION == 1):
                # sum reduction
                tMrDlogprobs[0] = mDlogprobs[0]
            else:
                # no reduction
                gDlogprobs = cute.local_tile(mDlogprobs, (self.epi_tile[0],), (pidm,))
                tMgDlogprobs = thr_copy_g2r_fp32.partition_S(cute.append_ones(gDlogprobs))
                cute.copy(tiled_copy_g2r_fp32, tMgDlogprobs, tMrDlogprobs, pred=tMCAcc_mask)

            # Reciprocal of accu, the upstream gradient zeroed for ignored rows, and
            # their product used to scale exp(logit - maximum) below.
            tMrAccu[0] = cute.arch.rcp_approx(tMrAccu[0])
            tMrDlogprobs[0] *= tMrLabels[0] != ignore_index
            tMr_d_acc_exp_logits = tMrDlogprobs[0] * tMrAccu[0]

            # ------ Partial output ------ #
            # [tileM, tileN]
            gDlogits_partial = cute.local_tile(
                mDlogits_partial, (self.epi_tile[0], self.epi_tile[1]), (pidm, pidn)
            )
            # blackwell supports STG.256
            copy_atom_r2g = cute.make_copy_atom(
                cute.nvgpu.CopyUniversalOp(), gDlogits_partial.element_type, num_bits_per_copy=256
            )
            tiled_copy_r2g = cute.make_tiled_copy_tv(
                copy_atom_r2g, epilogue_thread_layout, copy_atom_r2g.layout_dst_tv
            )
            thr_copy_r2g = tiled_copy_r2g.get_slice(tidx)

            # [CPY, loopM, loopN]
            tR2GCAcc = thr_copy_r2g.partition_S(cAcc)
            tR2GCAcc_pred = cute.make_fragment(tR2GCAcc.shape, cutlass.Boolean)
            # Store predicate: row < problem M and global column < problem N.
            for elem in cutlass.range(cute.size(tR2GCAcc_pred, mode=[0])):
                for row in cutlass.range(cute.size(tR2GCAcc_pred, mode=[1])):
                    for col in cutlass.range(cute.size(tR2GCAcc_pred, mode=[2])):
                        tR2GCAcc_pred[elem, row, col] = cute.elem_less(
                            pidm * self.epi_tile[0] + tR2GCAcc[elem, row, col][0], problem_mnk[0]
                        ) and cute.elem_less(
                            split_idx * self.vocab_per_split
                            + pidn * self.epi_tile[1]
                            + tR2GCAcc[elem, row, col][1],
                            problem_mnk[1],
                        )

            tR2GgDlogits = thr_copy_r2g.partition_D(gDlogits_partial)

            # for type conversion
            dLogits_half = cute.make_fragment(tTMEM_load_rAcc.shape, tR2GgDlogits.element_type)
            dLogits_half = cute.tiled_divide(dLogits_half, (cute.size(tR2GgDlogits, mode=[0]), 1))
            dLogits_half = cute.group_modes(dLogits_half, 2, cute.rank(dLogits_half))

            # Wait until the MMA warp has committed the accumulator.
            mma_pipeline.consumer_wait(mma_consumer_state)

            # Vocabulary range covered by this CTA, clipped to the split and to N.
            block_vocab_left_idx: cutlass.Int64 = (
                split_idx * self.vocab_per_split + pidn * self.epi_tile[1]
            )
            block_vocab_right_idx: cutlass.Int64 = min(
                split_idx * self.vocab_per_split + (pidn + 1) * self.epi_tile[1],
                min((split_idx + 1) * self.vocab_per_split, problem_mnk[1]),
            )
            num_n_subtiles: cutlass.Int64 = cute.ceil_div(
                (block_vocab_right_idx - block_vocab_left_idx), cute.size(tTMEM_load_rAcc, mode=[0])
            )
            for n_subtile in cutlass.range(num_n_subtiles):
                cute.copy(
                    tiled_copy_t2r,
                    tTMEM_load_tAcc[(None, None, None, n_subtile, mma_consumer_state.index)],
                    tTMEM_load_rAcc,
                )

                for idx in cutlass.range(cute.size(tTMEM_load_rAcc, mode=[0]), unroll_full=True):
                    # exp_logits
                    tTMEM_load_rAcc[idx] = cute.exp(tTMEM_load_rAcc[idx] - tMrMaximum[0])

                    # Global vocabulary index of this element, including the rank offset.
                    position: cutlass.Int64 = (
                        rank * problem_mnk[1]
                        + split_idx * self.vocab_per_split
                        + pidn * self.epi_tile[1]
                        + n_subtile * cute.size(tTMEM_load_rAcc, mode=[0])
                        + idx
                    )
                    mask: cutlass.Boolean = (
                        position == tMrLabels[0] and tMrLabels[0] != ignore_index
                    )
                    # d_logits
                    tTMEM_load_rAcc[idx] *= tMr_d_acc_exp_logits
                    tTMEM_load_rAcc[idx] += mask * -tMrDlogprobs[0]
                    dLogits_half[idx] = tTMEM_load_rAcc[idx].to(dLogits_half.element_type)

                for idx in cutlass.range(cute.size(dLogits_half, mode=[1]), unroll_full=True):
                    copy_id = n_subtile * cute.size(dLogits_half, mode=[1]) + idx
                    cute.copy(
                        tiled_copy_r2g,
                        dLogits_half[(None, idx, None)],
                        tR2GgDlogits[(None, None, copy_id)],
                        pred=tR2GCAcc_pred[((0, None), None, copy_id)],
                    )

            mma_pipeline.consumer_release(mma_consumer_state)
            mma_consumer_state.advance()

        # ------ Deallocate TMEM ------ #
        # All warps arrive first; the warp that allocated TMEM then frees it.
        self.cta_sync_barrier.arrive_and_wait()
        if warp_idx == self.empty_warp_ids[0]:
            cute.arch.relinquish_tmem_alloc_permit()
            cute.arch.dealloc_tmem(tmem_ptr, self.tmem_alloc_cols, is_two_cta=self.use_2cta_instrs)

    @cute.jit
    def __call__(
        self,
        split_idx: cutlass.Int32,
        hidden: cute.Tensor,
        weight: cute.Tensor,
        labels: cute.Tensor,
        dlogprobs: cute.Tensor,
        maximum: cute.Tensor,
        accu: cute.Tensor,
        dlogits_partial: cute.Tensor,
        scalarNumValidTokens: cute.Pointer,
        ignore_index: cutlass.Int64,
        rank: cutlass.Int32,
        stream: cuda.CUstream,
    ) -> None:
        """
        Validate the operands, build the tiled MMA, shared-memory layouts and TMA
        atoms, then launch ``kernel`` for vocabulary split ``split_idx`` on ``stream``.

        The problem shape is taken as ``(hidden.shape[0], weight.shape[0],
        hidden.shape[1])``.

        :raises RuntimeError: if ``hidden`` and ``weight`` differ in element type, the
            type is not FP16/BF16, the K extents differ, or K is not 16-byte
            (``hidden``) / 128-byte (``weight``) aligned.
        """
        a_dtype: Type[cutlass.Numeric] = hidden.element_type
        b_dtype: Type[cutlass.Numeric] = weight.element_type

        if cutlass.const_expr(hidden.element_type != weight.element_type):
            raise RuntimeError(
                f"data type don't match: {hidden.element_type} v.s. {weight.element_type}"
            )
        if cutlass.const_expr(hidden.element_type not in [cutlass.Float16, cutlass.BFloat16]):
            raise RuntimeError("hidden can only be FP16 or BF16")
        if cutlass.const_expr(hidden.layout.shape[1] != weight.layout.shape[1]):
            raise RuntimeError("K dimension doesn't match")

        problem_mnk = (hidden.layout.shape[0], weight.layout.shape[0], hidden.layout.shape[1])
        if cutlass.const_expr((problem_mnk[2] * a_dtype.width // 8) % 16 != 0):
            raise RuntimeError(f"K dimension is not 16B aligned: {problem_mnk[2]}")
        if cutlass.const_expr((problem_mnk[2] * b_dtype.width // 8) % 128 != 0):
            raise RuntimeError(f"K dimension is not 128B aligned: {problem_mnk[2]}")

        # Computed before `_setup_attributes`, i.e. with the constructor's mma_tiler;
        # `_compute_grid` only reads its M and N extents.
        grid = self._compute_grid(
            problem_mnk=problem_mnk,
            cluster_shape_mn=self.cluster_shape_mn,
            cta_tiler=self.mma_tiler,
        )

        a_major_mode = utils.LayoutEnum.from_tensor(hidden).mma_major_mode()
        b_major_mode = utils.LayoutEnum.from_tensor(weight).mma_major_mode()

        tiled_mma = sm100_utils.make_trivial_tiled_mma(
            a_dtype, a_major_mode, b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler[:2]
        )
        self._setup_attributes(tiled_mma, a_dtype, b_dtype)

        self.epi_tile = self.cta_tile_shape_mnk[:2]

        # Swizzle o [(tileM, tileK), loopM, loopK, stage]
        a_smem_layout_staged = sm100_utils.make_smem_layout_a(
            tiled_mma, self.mma_tiler, a_dtype, self.num_ab_stage
        )
        # Swizzle o [(tileN, tileK), loopN, loopK, stage]
        b_smem_layout_staged = sm100_utils.make_smem_layout_b(
            tiled_mma, self.mma_tiler, b_dtype, self.num_ab_stage
        )
        tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group)
        # NOTE(review): tma_store_op is not used below; the epilogue stores through
        # a universal copy atom instead.
        tma_store_op = cpasync.CopyBulkTensorTileS2GOp()

        # Swizzle o [(tileM, tileK), loopM, loopK]
        a_smem_layout = cute.select(a_smem_layout_staged, mode=[0, 1, 2])
        tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A(
            tma_load_op,
            hidden,
            a_smem_layout,
            self.mma_tiler,
            tiled_mma,
            self.cluster_layout_vmnk.shape,
        )
        # Swizzle o [(tileN, tileK), loopN, loopK]
        b_smem_layout = cute.select(b_smem_layout_staged, mode=[0, 1, 2])
        tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B(
            tma_load_op,
            weight,
            b_smem_layout,
            self.mma_tiler,
            tiled_mma,
            self.cluster_layout_vmnk.shape,
        )
        # Bytes of one A stage plus one B stage; passed to the pipeline as tx_count.
        a_copy_size = cute.size_in_bytes(a_dtype, a_smem_layout)
        b_copy_size = cute.size_in_bytes(b_dtype, b_smem_layout)
        self.tma_copy_ab_bytes = a_copy_size + b_copy_size

        @cute.struct
        class SharedStorage:
            """
            The shared storage for the backward kernel.
            """

            # Two Int64 slots per pipeline stage.
            load_ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2]
            mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2]

            tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1]
            # Slot that receives the TMEM address written by alloc_tmem.
            tmem_holding_buf: cutlass.Int32

            sA: cute.struct.Align[
                cute.struct.MemRange[a_dtype, cute.cosize(a_smem_layout_staged)],
                self.buffer_align_bytes,
            ]
            sB: cute.struct.Align[
                cute.struct.MemRange[b_dtype, cute.cosize(b_smem_layout_staged)],
                self.buffer_align_bytes,
            ]

        self.shared_storage = SharedStorage

        self.kernel(
            split_idx,
            tiled_mma,
            tma_atom_a,
            tma_tensor_a,
            tma_atom_b,
            tma_tensor_b,
            labels,
            dlogprobs,
            maximum,
            accu,
            dlogits_partial,
            scalarNumValidTokens,
            ignore_index,
            a_smem_layout_staged,
            b_smem_layout_staged,
            self.cluster_layout_vmnk,
            problem_mnk,
            rank,
        ).launch(
            grid=grid,
            block=[self.threads_per_cta, 1, 1],
            cluster=self.cluster_shape_mnk,
            stream=stream,
        )
# diff --git a/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/fwd_mainloop.py b/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/fwd_mainloop.py new file mode 100644 index 0000000000..4f4b52cb08 --- /dev/null +++ b/transformer_engine/common/cutedsl/linear_cross_entropy/blackwell/fwd_mainloop.py @@ -0,0 +1,653 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
+ +""" +Implementations of the fusion lm_head(Linear) + Cross-Entropy kernel +""" + +from typing import Tuple, Type + +import cuda.bindings.driver as cuda # type: ignore +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline # type: ignore +import cutlass.utils as utils # type: ignore +import cutlass.utils.blackwell_helpers as sm100_utils # type: ignore +from cutlass.cute.nvgpu import cpasync, tcgen05 + +SM100_TMEM_CAPACITY_COLUMNS: int = 512 + + +def make_thread_cooperative_group(size: int): + """ + Create a thread cooperative group. + """ + return pipeline.CooperativeGroup(pipeline.Agent.Thread, size, alignment=size) + + +class FwdMainLoop: + """ + This class implements the mainloop for forward process. + + Traits stored as attributes. + + :param acc_dtype: + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric] = cutlass.Float32, + use_2cta_instrs: bool = False, + mma_tiler_mn: Tuple[int, int] = (128, 256), + vocab_per_split: int = 512, + ): + """ + Configuration including: + - MMA instruction settings + - Cluster Shape + """ + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + # This is the shape covered by tiledMMA, not just single MMA instruction + self.mma_tiler = (*mma_tiler_mn, 1) + self.cta_tiler = (self.mma_tiler[0], vocab_per_split, self.mma_tiler[2]) + self.vocab_per_split = vocab_per_split + + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + self.cluster_shape_mn = (2, 1) if self.use_2cta_instrs else (1, 1) + + self.occupancy = 1 + # query SMEM capacity + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + + # the maximum columns per MMA is 256, and there is only one GEMM, so we can fully + # assign TMEM for that GEMM of different tiles. 
+ # so 512 = 2 * 256 + + self.threads_per_warp: int = 32 + # 1 warp for loading, 1 warp for issuing MMA, 1 WG for storing + self.epi_warp_ids = (0, 1, 2, 3) + self.load_warp_ids = 4 + self.mma_warp_ids = 5 + self.empty_warp_ids = (6, 7) + + self.threads_per_cta: int = self.threads_per_warp * len( + (*self.epi_warp_ids, self.load_warp_ids, self.mma_warp_ids, *self.empty_warp_ids) + ) + + self.cta_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, num_threads=self.threads_per_cta + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, num_threads=self.threads_per_cta + ) + + self.buffer_align_bytes: int = 1024 + self.num_regs_other: int = 32 + self.num_regs_epi: int = 192 + + def _compute_stages( + self, + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + ): + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, mma_tiler, a_dtype, 1 # only single stage + ) + b_smem_layout_stage_one = sm100_utils.make_smem_layout_b(tiled_mma, mma_tiler, b_dtype, 1) + a_bytes_per_stage = cute.size_in_bytes(a_dtype, a_smem_layout_stage_one) + b_bytes_per_stage = cute.size_in_bytes(b_dtype, b_smem_layout_stage_one) + num_acc_stage = 2 + num_a_stage = 4 + num_b_stage = 4 + num_epi_stage_per_tile = 4 + + return num_acc_stage, num_a_stage, num_b_stage, num_epi_stage_per_tile + + def _setup_attributes( + self, + tiled_mma: cute.TiledMma, + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + ): + self.cluster_shape_mnk = (*self.cluster_shape_mn, 1) + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), (tiled_mma.thr_id.shape,) + ) + + # this is fixed for dense MMA, k=16 + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + # 16*4 = 64; 64 * sizeof(FP16) = 128Bytes + mma_inst_tile_k: int = 4 + self.mma_tiler = (self.mma_tiler[0], self.mma_tiler[1], mma_inst_shape_k * mma_inst_tile_k) + + self.num_acc_stage, 
self.num_a_stage, self.num_b_stage, self.num_epi_stage_per_tile = ( + self._compute_stages(tiled_mma, self.mma_tiler, a_dtype, b_dtype) + ) + self.tmem_alloc_cols = self.num_acc_stage * self.mma_tiler[1] + assert self.tmem_alloc_cols <= SM100_TMEM_CAPACITY_COLUMNS + + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB: cute.Tensor, + mLabels: cute.Tensor, + mMax: cute.Tensor, + mAccu: cute.Tensor, + mLogprobs: cute.Tensor, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + cluster_layout_vmnk: cute.Layout, + problem_mnk: Tuple[int, int, int], + ignore_index: cutlass.Int64, + rank: cutlass.Int32, + ): + """ + The forward kernel for the mainloop. + """ + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + tidx, _, _ = cute.arch.thread_idx() + bidx, bidy, _ = cute.arch.block_idx() + # FIXME: block swizzling applied here + pidm, pidn = bidx, bidy + + # prefetch tma descriptors + if warp_idx == self.load_warp_ids: + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_a) + cute.nvgpu.cpasync.prefetch_descriptor(tma_atom_b) + + # declare SMEM + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + ab_pipeline = pipeline.PipelineTmaUmma.create( + num_stages=self.num_a_stage, + producer_group=make_thread_cooperative_group(len([self.load_warp_ids])), + consumer_group=make_thread_cooperative_group(len([self.mma_warp_ids])), + tx_count=self.tma_copy_a_bytes + self.tma_copy_b_bytes, + barrier_storage=storage.load_ab_mbar_ptr.data_ptr(), + ) + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_a_stage + ) + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_a_stage + ) + + mma_pipeline = 
pipeline.PipelineUmmaAsync.create( + num_stages=self.num_acc_stage, + producer_group=make_thread_cooperative_group(len([self.mma_warp_ids])), + consumer_group=make_thread_cooperative_group( + self.threads_per_warp * len(self.epi_warp_ids) + ), + barrier_storage=storage.mma_mbar_ptr.data_ptr(), + ) + mma_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + mma_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr.data_ptr() + if warp_idx == self.empty_warp_ids[0]: + with cute.arch.elect_one(): + cute.arch.mbarrier_init( + tmem_dealloc_mbar_ptr, self.threads_per_warp * len(self.epi_warp_ids) + ) + cute.arch.mbarrier_init_fence() + + # -------- SMEM partition ------------ # + # swizzle o [(tileM, tileK), loopM, loopK, Stage] + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # swizzle o [(tileN, tileK), loopN, loopK, stage] + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + + # FIXME: if 2 CTAs, modify here + thr_mma = tiled_mma.get_slice(0) + # [MMA, loopM, loopK, stage] + tCsA = thr_mma.make_fragment_A(sA) + # [MMA, loopN, loopK, stage] + tCsB = thr_mma.make_fragment_B(sB) + + # ---------- GMEM partition ----------- # + # [tileM, tileK, loopK] + gA = cute.local_tile(mA, (self.mma_tiler[0], self.mma_tiler[2]), (pidm, None)) + + # [vocab_size_per_split, dim] + mB_n = cute.local_tile( + mB, (self.vocab_per_split, cute.size(mB.layout.shape, mode=[1])), (pidn, 0) + ) + + # [tileN, tileK, loopN, loopK] + gB = cute.local_tile(mB_n, (self.mma_tiler[1], self.mma_tiler[2]), (None, None)) + + # [MMA, tileCntM, tileCntK, loopK] + tCgA = thr_mma.partition_A(gA) + # [MMA, tileCntN, tileCntK, loopN, loopK] + tCgB = thr_mma.partition_B(gB) + + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 
0)).shape) + # FIXME: if 2 CTAs, modify here + cta_rank_in_cluster = 0 + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + tTMAsA, tTMAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], # cta_coord, + a_cta_layout, + cute.group_modes(sA, 0, 3), # SMEM tensor + cute.group_modes(tCgA, 0, 3), # GMEM tensor + ) + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + tTMAsB, tTMAgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], # cta_coord + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # Allocate TMEM + tmem_holding_buf = storage.tmem_holding_buf + if warp_idx == self.empty_warp_ids[0]: + cute.arch.alloc_tmem( + self.tmem_alloc_cols, tmem_holding_buf, is_two_cta=self.use_2cta_instrs + ) + self.cta_sync_barrier.arrive_and_wait() + tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + + # [(tileM, tileN), loopM, loopN] + tmem_shape = (128, self.tmem_alloc_cols) + acc_shape = thr_mma.partition_shape_C(tmem_shape) + tCtC_fake = thr_mma.make_fragment_C(acc_shape) + tCtC = cute.make_tensor(tmem_ptr, tCtC_fake.layout) + + block_vocab_left_idx: cutlass.Int64 = pidn * self.vocab_per_split + block_vocab_right_idx: cutlass.Int64 = min( + (pidn + 1) * self.vocab_per_split, problem_mnk[1] + ) + num_n_tiles: cutlass.Int64 = cute.ceil_div( + (block_vocab_right_idx - block_vocab_left_idx), self.mma_tiler[1] + ) + + # /////// + # empty + # /////// + if warp_idx in self.empty_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + # /////// + # load + # /////// + if warp_idx == self.load_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + for n in cutlass.range(num_n_tiles): + for k in cutlass.range(cute.size(gA, mode=[2])): + ab_pipeline.producer_acquire(ab_producer_state) + cute.copy( + tma_atom_a, + tTMAgA[(None, k)], + 
tTMAsA[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + ) + cute.copy( + tma_atom_b, + tTMAgB[(None, n, k)], + tTMAsB[(None, ab_producer_state.index)], + tma_bar_ptr=ab_pipeline.producer_get_barrier(ab_producer_state), + ) + ab_pipeline.producer_commit(ab_producer_state) + ab_producer_state.advance() + + # /////// + # mma + # /////// + if warp_idx == self.mma_warp_ids: + cute.arch.warpgroup_reg_dealloc(self.num_regs_other) + + for n in cutlass.range(num_n_tiles): + # disable accumulate for the first tile + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + mma_pipeline.producer_acquire(mma_producer_state) + + for k in cutlass.range(cute.size(gA, mode=[2])): + ab_pipeline.consumer_wait(ab_consumer_state) + + for kblock_idx in cutlass.range(cute.size(tCsA, mode=[2]), unroll_full=True): + cute.gemm( + tiled_mma, + cute.append_ones(tCtC[(None, None, mma_producer_state.index)]), + tCsA[(None, None, kblock_idx, ab_consumer_state.index)], + tCsB[(None, None, kblock_idx, ab_consumer_state.index)], + cute.append_ones(tCtC[(None, None, mma_producer_state.index)]), + ) + # enable accumulate for the next tile + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + ab_pipeline.consumer_release(ab_consumer_state) + ab_consumer_state.advance() + + mma_pipeline.producer_commit(mma_producer_state) + mma_producer_state.advance() + + # ////////// + # epilogue + # ////////// + if warp_idx in self.epi_warp_ids: + cute.arch.warpgroup_reg_alloc(self.num_regs_epi) + + # epilog TMEM copy and partition + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + utils.LayoutEnum.ROW_MAJOR, # This is hard-coded + self.acc_dtype, + self.acc_dtype, + (self.epi_tile[0], self.epi_tile[1] // self.num_epi_stage_per_tile), + self.use_2cta_instrs, + ) + # [tileM, subTileN, loopM, CntSubTileN, loopN] + tAcc_epi = cute.flat_divide( + tCtC[((None, None), 0, None)], + (self.epi_tile[0], self.epi_tile[1] // self.num_epi_stage_per_tile), + ) 
+ tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + tTMEM_load_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # [(pattern), loopM, loopN, CntTileM, CntTileN] + tTMEM_load_tAcc = cute.group_modes(tTMEM_load_tAcc, 3, cute.rank(tTMEM_load_tAcc) - 1) + + cAcc = cute.make_identity_tensor(self.mma_tiler[:2]) + tCcAcc = thr_mma.partition_C(cAcc) + # [tileM, subTileN, loopM, CntSubTileN, CntTileN] + tCcAcc_epi = cute.flat_divide( + tCcAcc[((None, None), 0, None)], + (self.epi_tile[0], self.epi_tile[1] // self.num_epi_stage_per_tile), + ) + tTMEM_load_cAcc = thr_copy_t2r.partition_D(tCcAcc_epi) + tTMEM_load_cAcc_shape = cute.select(tTMEM_load_cAcc.shape, mode=[0, 1, 2]) + + # epilogue layouts + epilogue_thread_layout = cute.make_layout((128, 1)) + copy_atom_g2r = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), mLabels.element_type) + tiled_copy_g2r = cute.make_tiled_copy(copy_atom_g2r, epilogue_thread_layout, (128, 1)) + thr_copy_g2r = tiled_copy_g2r.get_slice(tidx) + + copy_atom_r2g = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), cutlass.Float32) + tiled_copy_r2g = cute.make_tiled_copy(copy_atom_r2g, epilogue_thread_layout, (128, 1)) + thr_copy_r2g = tiled_copy_r2g.get_slice(tidx) + + # auxiliary tensors + # [tileM] + gLabels = cute.local_tile(mLabels, (self.epi_tile[0],), (pidm,)) + + tLabelsCAcc = thr_copy_g2r.partition_S(cAcc)[(None, None, 0)] + tLabelsCAcc_mask = cute.make_fragment(tLabelsCAcc.shape, cutlass.Boolean) + # [(1, 1), 1] + tLabelsCAcc_mask[0] = cute.elem_less(pidm * self.epi_tile[0] + tidx, problem_mnk[0]) + # to align shape with gMax and gAccu + tLabelsCAcc_mask = cute.append_ones(tLabelsCAcc_mask) + + # [(1, 1), 1, 1] + tLabelsgLabels = thr_copy_g2r.partition_S(cute.append_ones(gLabels)) + tLabelsrLabels = cute.make_fragment(tLabelsgLabels.shape, tLabelsgLabels.element_type) + cute.copy(tiled_copy_g2r, tLabelsgLabels, tLabelsrLabels, pred=tLabelsCAcc_mask) + 
valid_mask: cutlass.Boolean = (tLabelsrLabels[0] != ignore_index) and tLabelsCAcc_mask[ + 0 + ] + + # [tileM, 1] + gMax = cute.local_tile(mMax, (self.epi_tile[0], 1), (pidm, pidn)) + # [(CPYM, CPYN), loopM, loopN] + tR2GgMax = thr_copy_r2g.partition_D(gMax) + tR2GrMax = cute.make_fragment(tR2GgMax.shape, tR2GgMax.element_type) + tR2GrMax.fill(-1e30) + + # [tileM, 1] + gAccu = cute.local_tile(mAccu, (self.epi_tile[0], 1), (pidm, pidn)) + # [(CPYM, CPYN), loopM, loopN] + tR2GgAccu = thr_copy_r2g.partition_D(gAccu) + tR2GrAccu = cute.make_fragment(tR2GgAccu.shape, tR2GgAccu.element_type) + tR2GrAccu.fill(0.0) + + # [tileM, 1] + gLogprobs = cute.append_ones(cute.local_tile(mLogprobs, (self.epi_tile[0],), (pidm,))) + # [(CPYM, CPYN), loopM, loopN] + tR2GgLogprobs = thr_copy_r2g.partition_D(gLogprobs) + tR2GrLogprobs = cute.make_fragment(tR2GgLogprobs.shape, tR2GgLogprobs.element_type) + tR2GrLogprobs.fill(0.0) + + # [(tileN // num_epi_stage_per_tile, 1), 1, 1] + tTMEM_load_rAcc = cute.make_fragment(tTMEM_load_cAcc_shape, self.acc_dtype) + + for n in cutlass.range(num_n_tiles): + mma_pipeline.consumer_wait(mma_consumer_state) + + left: cutlass.Int64 = block_vocab_left_idx + n * self.epi_tile[1] + right: cutlass.Int64 = min( + (n + 1) * self.epi_tile[1] + block_vocab_left_idx, block_vocab_right_idx + ) + num_n_subtiles: cutlass.Int64 = cute.ceil_div( + (right - left), cute.size(tTMEM_load_rAcc, mode=[0]) + ) + for n_subtile in cutlass.range(num_n_subtiles): + cute.copy( + tiled_copy_t2r, + tTMEM_load_tAcc[(None, None, None, n_subtile, mma_consumer_state.index)], + tTMEM_load_rAcc, + ) + + for idx in cutlass.range( + cute.size(tTMEM_load_rAcc, mode=[0]), unroll_full=True + ): + local_position: cutlass.Int64 = ( + n * self.epi_tile[1] + + n_subtile * cute.size(tTMEM_load_rAcc, mode=[0]) + + idx + ) + if (block_vocab_left_idx + local_position) < block_vocab_right_idx: + _max_old = tR2GrMax[0] + tR2GrMax[0] = cute.arch.fmax(tR2GrMax[0], tTMEM_load_rAcc[idx]) + exp_logits = 
cute.exp(tTMEM_load_rAcc[idx] - tR2GrMax[0]) + coeff = cute.exp(_max_old - tR2GrMax[0]) + tR2GrAccu[0] = coeff * tR2GrAccu[0] + exp_logits + + position: cutlass.Int64 = ( + rank * problem_mnk[1] + pidn * self.vocab_per_split + local_position + ) + mask: cutlass.Boolean = valid_mask and (position == tLabelsrLabels[0]) + tR2GrLogprobs[0] += mask * tTMEM_load_rAcc[idx] + + mma_pipeline.consumer_release(mma_consumer_state) + mma_consumer_state.advance() + + cute.copy(tiled_copy_r2g, tR2GrMax, tR2GgMax, pred=tLabelsCAcc_mask) + cute.copy(tiled_copy_r2g, tR2GrAccu, tR2GgAccu, pred=tLabelsCAcc_mask) + + vocab_left_idx: cutlass.Int64 = rank * problem_mnk[1] + pidn * self.vocab_per_split + vocab_right_idx: cutlass.Int64 = rank * problem_mnk[1] + min( + (pidn + 1) * self.vocab_per_split, problem_mnk[1] + ) + valid: cutlass.Boolean = ( + tLabelsrLabels[0] >= vocab_left_idx and tLabelsrLabels[0] < vocab_right_idx + ) + tLabelsCAcc_mask[0] &= valid + + cute.copy(tiled_copy_r2g, tR2GrLogprobs, tR2GgLogprobs, pred=tLabelsCAcc_mask) + + # Dealloc TMEM + self.cta_sync_barrier.arrive_and_wait() + if warp_idx == self.empty_warp_ids[0]: + cute.arch.relinquish_tmem_alloc_permit() + cute.arch.dealloc_tmem(tmem_ptr, self.tmem_alloc_cols, is_two_cta=self.use_2cta_instrs) + + @staticmethod + def _compute_grid( + problem_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + cta_tiler: Tuple[int, int, int], + num_splits: int, + ) -> Tuple[int, int, int]: + + cluster_shape = (*cluster_shape_mn, 1) + + grid = cute.round_up( + (cute.ceil_div(problem_mnk[0], cta_tiler[0]), num_splits, 1), cluster_shape + ) + return grid + + @cute.jit + def __call__( + self, + hidden: cute.Tensor, + weight: cute.Tensor, + labels: cute.Tensor, + _logprobs: cute.Tensor, + _max: cute.Tensor, + _accu: cute.Tensor, + ignore_index: cutlass.Int64, + rank: cutlass.Int32, + stream: cuda.CUstream, + ) -> None: + a_dtype: Type[cutlass.Numeric] = hidden.element_type + b_dtype: Type[cutlass.Numeric] = 
weight.element_type + + if cutlass.const_expr(hidden.element_type != weight.element_type): + raise RuntimeError( + f"data type don't match: {hidden.element_type} v.s. {weight.element_type}" + ) + if cutlass.const_expr(hidden.element_type not in [cutlass.Float16, cutlass.BFloat16]): + raise RuntimeError("hidden can only be FP16 or BF16") + if cutlass.const_expr(hidden.layout.shape[1] != weight.layout.shape[1]): + raise RuntimeError("K dimension doesn't match") + + problem_mnk = (hidden.layout.shape[0], weight.layout.shape[0], hidden.layout.shape[1]) + if cutlass.const_expr((problem_mnk[2] * a_dtype.width // 8) % 16 != 0): + raise RuntimeError(f"K dimension is not 16B aligned: {problem_mnk[2]}") + + num_splits = cute.ceil_div(problem_mnk[1], self.vocab_per_split) + + grid = self._compute_grid( + problem_mnk=problem_mnk, + cluster_shape_mn=self.cluster_shape_mn, + cta_tiler=self.cta_tiler, + num_splits=num_splits, + ) + a_major_mode = utils.LayoutEnum.from_tensor(hidden).mma_major_mode() + b_major_mode = utils.LayoutEnum.from_tensor(weight).mma_major_mode() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + a_dtype, a_major_mode, b_major_mode, self.acc_dtype, self.cta_group, self.mma_tiler[:2] + ) + + self._setup_attributes(tiled_mma, a_dtype, b_dtype) + if cutlass.const_expr((problem_mnk[2] * a_dtype.width // 8) % 128 != 0): + raise RuntimeError(f"K dimension is not 128B aligned: {problem_mnk[2]}") + + self.epi_tile = self.mma_tiler[:2] + + # Swizzle o [(tileM, tileK), loopM, loopK, stage] + a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, self.mma_tiler, a_dtype, self.num_a_stage + ) + # Swizzle o [(tileN, tileK), loopN, loopK, stage] + b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, self.mma_tiler, b_dtype, self.num_b_stage + ) + + # TMA loading + tma_load_op = cpasync.CopyBulkTensorTileG2SOp(self.cta_group) + tma_store_op = cpasync.CopyBulkTensorTileS2GOp() + + # Swizzle o [(tileM, tileK), loopM, loopK] + a_smem_layout = 
cute.select(a_smem_layout_staged, mode=[0, 1, 2]) + # create tma copy atom for hidden, + # and the cooresponding tma descriptor tensor + tma_atom_a, tma_desc_a = cute.nvgpu.make_tiled_tma_atom_A( + tma_load_op, + hidden, # gmem_tensor + a_smem_layout, # SMEM layout + self.mma_tiler, # MMA tiler + tiled_mma, # TiledMMA + self.cluster_layout_vmnk.shape, # cluster_shape_vmnk + ) + # Swizzle o [(tileN, tileK), loopN, loopK] + b_smem_layout = cute.select(b_smem_layout_staged, mode=[0, 1, 2]) + tma_atom_b, tma_desc_b = cute.nvgpu.make_tiled_tma_atom_B( + tma_load_op, + weight, # gmem_tensor + b_smem_layout, # SMEM layout + self.mma_tiler, # MMA tiler + tiled_mma, # TiledMMA + self.cluster_layout_vmnk.shape, # cluster_shape_vmnk + ) + a_copy_size = cute.size_in_bytes(a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(b_dtype, b_smem_layout) + self.tma_copy_a_bytes = a_copy_size + self.tma_copy_b_bytes = b_copy_size + + assert self.num_a_stage == self.num_b_stage + + @cute.struct + class SharedStorage: + """ + The shared storage for the forward kernel. 
+ """ + + # pipeline barriers, 2 = producer + consumer + load_ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_a_stage * 2] + mma_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tmem_dealloc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1] + # tmem holding buffer + tmem_holding_buf: cutlass.Int32 + # SMEM tensors + sA: cute.struct.Align[ + cute.struct.MemRange[a_dtype, cute.cosize(a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[b_dtype, cute.cosize(b_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # launch kernel + self.kernel( + tiled_mma, + tma_atom_a, + tma_desc_a, + tma_atom_b, + tma_desc_b, + labels, + _max, + _accu, + _logprobs, + a_smem_layout_staged, + b_smem_layout_staged, + self.cluster_layout_vmnk, + problem_mnk, + ignore_index, + rank, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + ) + return None diff --git a/transformer_engine/common/cutedsl/linear_cross_entropy/utils.py b/transformer_engine/common/cutedsl/linear_cross_entropy/utils.py new file mode 100644 index 0000000000..f1ee2eb3dc --- /dev/null +++ b/transformer_engine/common/cutedsl/linear_cross_entropy/utils.py @@ -0,0 +1,45 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import typing +from enum import Enum + + +class EntropyReductionEnum(Enum): + """ + Enum for the reduction method of cross entropy. 
+ """ + + kNone = 0 + kSum = 1 + kMean = 2 + + +def str_to_reduction_enum(reduction: typing.Literal["none", "sum", "mean"]) -> EntropyReductionEnum: + """ + str -> EntropyReductionEnum + """ + _enum = EntropyReductionEnum.kNone + if reduction == "none": + _enum = EntropyReductionEnum.kNone + elif reduction == "sum": + _enum = EntropyReductionEnum.kSum + elif reduction == "mean": + _enum = EntropyReductionEnum.kMean + else: + raise ValueError(f"Invalid reduction: {reduction}") + return _enum + + +class BackwardMethodEnum(Enum): + """ + Enum for the backward method of linear cross entropy. + """ + + # two separate kernels for d_hidden and d_weight, respectively + kTwoKernels = 0 + # calculate partial d_logits along its N dimension + kDlogitsSplitN = 1 + # fuse d_hidden and d_weight into a single kernel + kFused = 2 diff --git a/transformer_engine/common/triton/linear_cross_entropy.py b/transformer_engine/common/triton/linear_cross_entropy.py new file mode 100644 index 0000000000..2a3483585f --- /dev/null +++ b/transformer_engine/common/triton/linear_cross_entropy.py @@ -0,0 +1,252 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Kernels for Linear Cross Entropy written with OpenAI Triton.""" + +import triton # type: ignore +import triton.language as tl # type: ignore + +# NOTE: tl.pointer_type() is not available in Triton 3.3.0 + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE_M": 1024}, num_stages=3, num_warps=32), + triton.Config({"BLOCK_SIZE_M": 2048}, num_stages=3, num_warps=32), + ], + key=["num_tokens"], +) +@triton.jit +def get_num_valid_tokens( + num_tokens: tl.int64, + ignore_index: tl.int64, + labels_ptr, #: tl.pointer_type(tl.int64), + stride_labels: tl.int64, + num_valid_tokens_ptr, #: tl.pointer_type(tl.int64), + BLOCK_SIZE_M: tl.constexpr, +): + """ + Calculate the number of valid tokens in the labels tensor. 
+ """ + num_pid_m: tl.int64 = tl.cdiv(num_tokens, BLOCK_SIZE_M) + + num_valid_tokens: tl.int64 = tl.zeros((), dtype=tl.int64) + for m in range(0, num_pid_m): + offs_am = m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + labels = tl.load( + labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=ignore_index + ) + + valid_labels_mask = labels != ignore_index + num_valid_tokens += (tl.sum(valid_labels_mask.to(tl.int32), axis=0)).to(tl.int64) + tl.store(num_valid_tokens_ptr, num_valid_tokens) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) +@triton.jit +def forward_dp_epilogue( + num_tokens: tl.int64, + num_splits: tl.int64, # TODO: maybe this could be a constexpr + ignore_index: tl.int64, + labels_ptr, #: tl.pointer_type(tl.int64), + stride_labels: tl.int64, + num_valid_tokens_ptr, #: tl.pointer_type(tl.int64), + max_ptr, #: tl.pointer_type(tl.float32), + stride_max_m: tl.int64, + stride_max_n: tl.int64, + accu_ptr, #: tl.pointer_type(tl.float32), + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_max_ptr, #: tl.pointer_type(tl.float32), + stride_global_max: tl.int64, + global_accu_ptr, #: tl.pointer_type(tl.float32), + stride_global_accu: tl.int64, + global_logprobs_ptr, #: tl.pointer_type(tl.float32), + stride_global_logprobs: tl.int64, + global_logprobs_scalar_ptr, #: tl.pointer_type(tl.float32), + REDUCTION: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + forward epilogue in dp + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _max = tl.load( + max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n, + 
mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _accu = tl.load( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + # local reduction + _max_old = global_max + _local_max = tl.max(_max, axis=1, return_indices=False) + global_max = tl.maximum(global_max, _local_max) + + _scale = tl.exp(_max - global_max[:, None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # store maximum + tl.store(global_max_ptr + offs_m * stride_global_max, global_max, mask=offs_m < num_tokens) + # store accumulate + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + # update logprobs + labels = tl.load( + labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=ignore_index + ) + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + label_mask = labels != ignore_index + global_logprobs = tl.where(label_mask, global_logprobs, 0.0) + + if REDUCTION == 0: # no-reduction + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) + elif REDUCTION == 1: # sum + global_logprobs_scalar = tl.sum(global_logprobs, axis=0) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + elif REDUCTION == 2: # mean + num_valid_tokens = tl.load(num_valid_tokens_ptr) + global_logprobs_scalar = tl.fdiv( + tl.sum(global_logprobs, axis=0), num_valid_tokens.to(tl.float32) + ) + tl.atomic_add(global_logprobs_scalar_ptr, global_logprobs_scalar) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) +@triton.jit +def forward_tp_epilogue( + num_tokens: tl.int64, + 
num_splits: tl.int64, + reduced_max_ptr, #: tl.pointer_type(tl.float32), + stride_reduced_max_m: tl.int64, + stride_reduced_max_n: tl.int64, + original_max_ptr, #: tl.pointer_type(tl.float32), + stride_original_max_m: tl.int64, + stride_original_max_n: tl.int64, + accu_ptr, #: tl.pointer_type(tl.float32), + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_max_ptr, #: tl.pointer_type(tl.float32), + stride_global_max: tl.int64, + global_accu_ptr, #: tl.pointer_type(tl.float32), + stride_global_accu: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + forward epilogue in tp + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load( + reduced_max_ptr + + offs_m[:, None] * stride_reduced_max_m + + offs_n[None, :] * stride_reduced_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _original_max = tl.load( + original_max_ptr + + offs_m[:, None] * stride_original_max_m + + offs_n[None, :] * stride_original_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _accu = tl.load( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + # local reduction + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # store + tl.store(global_max_ptr + offs_m * stride_global_max, 
global_max, mask=offs_m < num_tokens) + tl.store(global_accu_ptr + offs_m * stride_global_accu, global_accu, mask=offs_m < num_tokens) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def forward_tp_epilogue_update_logprobs( + num_tokens: tl.int64, + ignore_index: tl.int64, + num_valid_tokens_ptr, #: tl.pointer_type(tl.int64), + labels_ptr, #: tl.pointer_type(tl.int64), + stride_labels: tl.int64, + logprobs_ptr, #: tl.pointer_type(tl.float32), + stride_logprobs: tl.int64, + maximum_ptr, #: tl.pointer_type(tl.float32), + stride_maximum: tl.int64, + accumulate_ptr, #: tl.pointer_type(tl.float32), + stride_accumulate: tl.int64, + logprobs_scalar_ptr, #: tl.pointer_type(tl.float32), + REDUCTION: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, +): + """ + update logprobs in tp + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + logprobs = tl.load(logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens) + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load(accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens) + + labels = tl.load( + labels_ptr + offs_m * stride_labels, mask=offs_m < num_tokens, other=ignore_index + ) + label_mask = labels != ignore_index + + logprobs = maximum + tl.log(accumulate) - logprobs + logprobs = tl.where(label_mask, logprobs, 0.0) + + if REDUCTION == 0: # no-reduction + tl.store(logprobs_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens) + elif REDUCTION == 1: # sum + logprobs_scalar = tl.sum(logprobs, axis=0) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) + elif REDUCTION == 2: # mean + num_valid_tokens = tl.load(num_valid_tokens_ptr) + logprobs_scalar = tl.fdiv(tl.sum(logprobs, axis=0), num_valid_tokens.to(tl.float32)) + tl.atomic_add(logprobs_scalar_ptr, logprobs_scalar) diff --git a/transformer_engine/pytorch/__init__.py 
b/transformer_engine/pytorch/__init__.py index 5e1eb6954b..59eb12fa14 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -77,6 +77,7 @@ from transformer_engine.pytorch.tensor import MXFP8Tensor from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor import NVFP4Tensor +from transformer_engine.pytorch.linear_cross_entropy import linear_cross_entropy try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/cutedsl/__init__.py b/transformer_engine/pytorch/cutedsl/__init__.py new file mode 100644 index 0000000000..1906811477 --- /dev/null +++ b/transformer_engine/pytorch/cutedsl/__init__.py @@ -0,0 +1 @@ +"""Pytorch Wrappers for CUTLASS DSL kernels""" \ No newline at end of file diff --git a/transformer_engine/pytorch/cutedsl/linear_cross_entropy_blackwell.py b/transformer_engine/pytorch/cutedsl/linear_cross_entropy_blackwell.py new file mode 100644 index 0000000000..46b263cf59 --- /dev/null +++ b/transformer_engine/pytorch/cutedsl/linear_cross_entropy_blackwell.py @@ -0,0 +1,414 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import typing +from dataclasses import dataclass, field + +import cuda.bindings.driver as cuda # type: ignore +import cutlass +import cutlass.cute as cute +import torch +import torch.distributed as dist +import triton # type: ignore +from cutlass.cute.runtime import from_dlpack + +from transformer_engine.common.cutedsl.linear_cross_entropy import ( + blackwell, + utils, +) +from transformer_engine.common.triton import linear_cross_entropy as triton_kernels + + +@dataclass +class FwdConfig: + """ + The configuration for the forward pass. 
@dataclass
class FwdConfig:
    """
    Process-wide state shared by every call to ``forward``.

    All fields are filled in lazily by ``forward``; constructing the config
    does not touch CUDA.
    """

    # Side stream on which `forward` runs the all-reduce of the per-token
    # logprobs buffer in TP mode. Created on the first `forward` call.
    _dedicated_stream: typing.Optional[torch.cuda.Stream] = None
    # Two events: [0] is recorded on the current stream before the side-stream
    # all-reduce starts, [1] is recorded on the side stream once it was issued.
    _dedicated_events: typing.List[torch.cuda.Event] = field(default_factory=list)
    # True once the stream and the events above have been created.
    _initialized: bool = field(default=False)
    # Compiled forward main-loop kernels, keyed by "vocab_size/dim/dtype".
    _fwd_mainloop_kernels: typing.Dict[str, cute.kernel] = field(default_factory=dict)


@dataclass
class BwdConfig:
    """
    Process-wide state shared by every call to ``backward``.
    """

    # Compiled backward kernels, keyed by "vocab_size/dim/reduction/dtype".
    _bwd_kernel: typing.Dict[str, cute.kernel] = field(default_factory=dict)


_fwd_config = FwdConfig()
_bwd_config = BwdConfig()


def forward(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    reduction: typing.Literal["none", "sum", "mean"] = "mean",
    ignore_index: int = -100,
    sequence_parallel: bool = False,
) -> typing.Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int, torch.Tensor]:
    """
    Host-side forward of the fused linear + cross-entropy.

    Args:
        hidden: contiguous CUDA tensor, 2-D ``[tokens, dim]`` or 3-D; the
            leading dims of a 3-D input are flattened into tokens.
        weight: contiguous CUDA tensor ``[vocab_size, dim]``.
        labels: contiguous CUDA tensor, 1-D for 2-D ``hidden`` and 2-D for
            3-D ``hidden``. With ``sequence_parallel`` it must hold
            ``tp_world_size`` times as many entries as ``hidden`` has tokens.
        tp_group: tensor-parallel process group, or None.
        reduction: "none", "sum" or "mean".
        ignore_index: label value that is excluded from the loss.
        sequence_parallel: if True (and the TP group has more than one rank),
            ``hidden`` is all-gathered along dim 0 across ``tp_group`` first.

    Returns:
        ``(logprobs, maximum, accumulate, num_valid_tokens, tp_rank,
        tp_world_size, global_hidden)``. ``logprobs`` is ``(num_tokens,)`` for
        "none" and 0-dim otherwise; ``maximum`` and ``accumulate`` are
        ``(num_tokens,)`` fp32; ``num_valid_tokens`` is a 0-dim int64 tensor;
        ``global_hidden`` is the gathered hidden (or ``hidden`` itself).
    """
    tp_rank = 0 if tp_group is None else torch.distributed.get_rank(tp_group)
    tp_world_size = 1 if tp_group is None else torch.distributed.get_world_size(tp_group)
    in_tp_mode = (tp_group is not None) and (tp_world_size > 1)

    assert hidden.is_cuda and weight.is_cuda and labels.is_cuda
    assert weight.device == hidden.device and labels.device == hidden.device

    # hidden could be [batch, seqlen, dim] or [seqlen, batch, dim] or [tokens, dim]
    assert hidden.dim() == 2 or hidden.dim() == 3
    # weight must be [vocab_size, dim]
    assert weight.dim() == 2
    # labels could be [batch, seqlen] or [seqlen, batch] or [tokens]
    assert (hidden.dim() == 2 and labels.dim() == 1) or (hidden.dim() == 3 and labels.dim() == 2)
    assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous()

    hidden_view = hidden.view(-1, hidden.shape[-1])
    labels_view = labels.view(-1)

    assert (sequence_parallel and hidden_view.shape[0] * tp_world_size == labels_view.shape[0]) or (
        not sequence_parallel and hidden_view.shape[0] == labels_view.shape[0]
    )
    assert hidden_view.shape[1] == weight.shape[1]

    global_hidden = hidden
    if in_tp_mode and sequence_parallel:
        # Every rank holds an equal shard along dim 0; gather the full input.
        partial_hidden_shape = hidden.shape
        global_hidden_shape = (partial_hidden_shape[0] * tp_world_size, *partial_hidden_shape[1:])
        global_hidden = torch.empty(global_hidden_shape, dtype=hidden.dtype, device=hidden.device)
        dist.all_gather_into_tensor(global_hidden, hidden, group=tp_group)
        assert global_hidden.is_contiguous()
        hidden_view = global_hidden.view(-1, global_hidden.shape[-1])

    num_tokens, dim = hidden_view.shape
    vocab_size, _ = weight.shape

    if not _fwd_config._initialized:
        # NOTE: the side stream is bound to the device of the first call.
        _fwd_config._dedicated_stream = torch.cuda.Stream(hidden.device)
        _fwd_config._dedicated_events = [torch.cuda.Event() for _ in range(2)]
        _fwd_config._initialized = True

    REDUCTION = utils.str_to_reduction_enum(reduction)
    no_reduction = REDUCTION == utils.EntropyReductionEnum.kNone

    # declare logprobs
    if no_reduction:
        logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)
        if in_tp_mode:
            # summed across ranks below, so it has to start from zero
            logprobs.zero_()
    else:
        # zero-initialised because the TP epilogue accumulates into it with atomic adds
        logprobs = torch.zeros((), device=hidden.device, dtype=torch.float32)

    # declare auxiliary tensors
    maximum = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)
    accumulate = torch.empty_like(maximum, dtype=torch.float32)
    num_valid_tokens = torch.empty((), device=hidden.device, dtype=torch.int64)
    assert (
        maximum.is_contiguous() and accumulate.is_contiguous() and num_valid_tokens.is_contiguous()
    )

    # declare intermediate tensors
    # NOTE: this is a parameter for tuning
    vocab_per_split = 512 * 6
    num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split
    _max = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)
    _accu = torch.empty((num_tokens, num_splits), device=hidden.device, dtype=torch.float32)
    if no_reduction:
        # the per-token buffer doubles as the final output
        _logprobs = logprobs
    else:
        _logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32)
        if in_tp_mode:
            _logprobs.zero_()
    assert _max.is_contiguous() and _accu.is_contiguous() and _logprobs.is_contiguous()

    triton_kernels.get_num_valid_tokens[(1,)](
        num_tokens, ignore_index, labels_view, labels_view.stride(0), num_valid_tokens
    )

    # need to compile the kernel for the first time
    hidden_packed = from_dlpack(hidden_view.detach(), assumed_align=16).mark_compact_shape_dynamic(
        mode=0
    )
    weight_packed = from_dlpack(weight.detach(), assumed_align=16)
    labels_packed = from_dlpack(labels_view.detach(), assumed_align=8).mark_compact_shape_dynamic(
        mode=0
    )
    logprobs_packed = from_dlpack(_logprobs, assumed_align=16).mark_compact_shape_dynamic(mode=0)
    _max_packed = from_dlpack(_max, assumed_align=8).mark_compact_shape_dynamic(
        mode=0, stride_order=(0, 1)
    )
    _accu_packed = from_dlpack(_accu, assumed_align=8).mark_compact_shape_dynamic(
        mode=0, stride_order=(0, 1)
    )
    cuda_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

    # The same argument list is used to compile and to launch the main loop.
    mainloop_args = (
        hidden_packed,
        weight_packed,
        labels_packed,
        logprobs_packed,
        _max_packed,
        _accu_packed,
        ignore_index,
        tp_rank,
        cuda_stream,
    )

    # VocabSize and Dim are fixed for a given model,
    # only the number of tokens can vary
    key = f"vocab_size:{vocab_size}+dim:{dim}+dtype:{hidden_view.dtype}"
    fwd_mainloop_compiled_kernel = _fwd_config._fwd_mainloop_kernels.get(key)
    if fwd_mainloop_compiled_kernel is None:
        fwd_mainloop_kernel = blackwell.fwd_mainloop.FwdMainLoop(vocab_per_split=vocab_per_split)
        fwd_mainloop_compiled_kernel = cute.compile(fwd_mainloop_kernel, *mainloop_args)
        _fwd_config._fwd_mainloop_kernels[key] = fwd_mainloop_compiled_kernel
    fwd_mainloop_compiled_kernel(*mainloop_args)

    def grid(meta):
        return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),)

    if not in_tp_mode:
        triton_kernels.forward_dp_epilogue[grid](
            num_tokens,
            num_splits,
            ignore_index,
            labels_view,
            labels_view.stride(0),
            num_valid_tokens,
            _max,
            _max.stride(0),
            _max.stride(1),
            _accu,
            _accu.stride(0),
            _accu.stride(1),
            maximum,
            maximum.stride(0),
            accumulate,
            accumulate.stride(0),
            _logprobs,
            _logprobs.stride(0),
            logprobs,
            triton.language.constexpr(REDUCTION.value),
        )
    else:
        # Keep the rank-local maxima: the epilogue rescales the local partial
        # sums from the local to the group-wide maximum.
        _max_backup = _max.clone()
        dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=tp_group)

        # Sum the per-token buffer across ranks on the side stream so that it
        # overlaps with the epilogue below.
        torch.cuda.current_stream().record_event(_fwd_config._dedicated_events[0])
        with torch.cuda.stream(_fwd_config._dedicated_stream):
            _fwd_config._dedicated_stream.wait_event(_fwd_config._dedicated_events[0])
            dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=tp_group)
            _fwd_config._dedicated_stream.record_event(_fwd_config._dedicated_events[1])

        triton_kernels.forward_tp_epilogue[grid](
            num_tokens,
            num_splits,
            _max,
            _max.stride(0),
            _max.stride(1),
            _max_backup,
            _max_backup.stride(0),
            _max_backup.stride(1),
            _accu,
            _accu.stride(0),
            _accu.stride(1),
            maximum,
            maximum.stride(0),
            accumulate,
            accumulate.stride(0),
        )
        # reduce accumulate
        dist.all_reduce(accumulate, op=dist.ReduceOp.SUM, group=tp_group)

        # update logprobs (only after the side-stream all-reduce was issued)
        torch.cuda.current_stream().wait_event(_fwd_config._dedicated_events[1])
        triton_kernels.forward_tp_epilogue_update_logprobs[grid](
            num_tokens,
            ignore_index,
            num_valid_tokens,
            labels_view,
            labels_view.stride(0),
            _logprobs,
            _logprobs.stride(0),
            maximum,
            maximum.stride(0),
            accumulate,
            accumulate.stride(0),
            logprobs,
            REDUCTION.value,
        )

    return logprobs, maximum, accumulate, num_valid_tokens, tp_rank, tp_world_size, global_hidden
def backward(
    dlogprobs: torch.Tensor,
    global_hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    maximum: torch.Tensor,
    accu: torch.Tensor,
    num_valid_tokens: torch.Tensor,
    reduction: typing.Literal["none", "sum", "mean"] = "mean",
    ignore_index: int = -100,
    tp_group: typing.Optional[dist.ProcessGroup] = None,
    tp_rank: int = 0,
    tp_world_size: int = 1,
    sequence_parallel: bool = False,
) -> typing.Tuple[torch.Tensor, torch.Tensor]:
    """
    Host-side backward of the fused linear + cross-entropy.

    The gradient w.r.t. the logits is produced one vocabulary slice at a
    time (``vocab_per_split`` columns) and immediately folded into
    ``d_hidden`` and ``d_weight``.

    Args:
        dlogprobs: CUDA gradient of the loss; ``(num_tokens,)`` for
            reduction "none", 0-dim otherwise. It may be non-contiguous
            (autograd can hand over expanded tensors); it is made contiguous
            here before being passed to the kernel.
        global_hidden: hidden states returned by the forward pass.
        weight: ``[vocab_size, dim]`` weight used in the forward pass.
        labels: labels used in the forward pass.
        maximum, accu: per-token statistics returned by the forward pass.
        num_valid_tokens: 0-dim int64 CUDA tensor returned by the forward pass.
        reduction, ignore_index: same values as in the forward pass.
        tp_group, tp_rank, tp_world_size: tensor-parallel setup.
        sequence_parallel: if True in TP mode, only this rank's slice of
            ``d_hidden`` along dim 0 is returned.

    Returns:
        ``(d_hidden, d_weight)``.

    Raises:
        NotImplementedError: if the selected backward method is unsupported.
    """
    in_tp_mode = (tp_group is not None) and (tp_world_size > 1)

    hidden_view = global_hidden.view(-1, global_hidden.shape[-1])
    labels_view = labels.view(-1)

    num_tokens, dim = hidden_view.shape
    vocab_size, _ = weight.shape

    REDUCTION = utils.str_to_reduction_enum(reduction)
    assert dlogprobs.is_cuda
    # Autograd may pass an expanded (stride-0) gradient, e.g. for
    # `loss.sum().backward()` with reduction="none". The kernel reads the
    # buffer with dense strides, so materialise it instead of asserting.
    dlogprobs = dlogprobs.contiguous()
    dlogprobs_view = dlogprobs.view(-1)
    assert (REDUCTION == utils.EntropyReductionEnum.kNone and dlogprobs.shape == (num_tokens,)) or (
        REDUCTION != utils.EntropyReductionEnum.kNone and dlogprobs.dim() == 0
    )

    assert (
        num_valid_tokens.dim() == 0
        and num_valid_tokens.is_cuda
        and num_valid_tokens.dtype == torch.int64
    )

    d_hidden = torch.empty_like(global_hidden)
    d_weight = torch.empty_like(weight)
    assert d_hidden.is_contiguous() and d_weight.is_contiguous()

    # FIXME: implement different backward methods
    _backward = utils.BackwardMethodEnum.kDlogitsSplitN
    if _backward == utils.BackwardMethodEnum.kDlogitsSplitN:
        vocab_per_split = 512 * 6
        num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split

        # Scratch buffer for one vocabulary slice of d_logits, reused by every split.
        _d_logits = torch.empty(
            (num_tokens, vocab_per_split), device=global_hidden.device, dtype=global_hidden.dtype
        )

        hidden_packed = from_dlpack(
            hidden_view.detach(), assumed_align=16
        ).mark_compact_shape_dynamic(mode=0)
        weight_packed = from_dlpack(weight.detach(), assumed_align=16)
        labels_packed = from_dlpack(
            labels_view.detach(), assumed_align=8
        ).mark_compact_shape_dynamic(mode=0)
        dlogprobs_packed = from_dlpack(
            dlogprobs_view.detach(), assumed_align=8
        ).mark_compact_shape_dynamic(mode=0)
        maximum_packed = from_dlpack(maximum.detach(), assumed_align=8).mark_compact_shape_dynamic(
            mode=0
        )
        accu_packed = from_dlpack(accu.detach(), assumed_align=8).mark_compact_shape_dynamic(mode=0)
        dlogits_packed = from_dlpack(_d_logits, assumed_align=32).mark_compact_shape_dynamic(mode=0)
        scalarNumValidTokens_packed = cute.runtime.make_ptr(
            cutlass.Int64, num_valid_tokens.data_ptr(), cute.AddressSpace.gmem, assumed_align=8
        )

        stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

        # Everything except the leading split index is identical for
        # compilation and for each launch.
        kernel_args = (
            hidden_packed,
            weight_packed,
            labels_packed,
            dlogprobs_packed,
            maximum_packed,
            accu_packed,
            dlogits_packed,
            scalarNumValidTokens_packed,
            ignore_index,
            tp_rank,
            stream,
        )

        key = f"vocab_size:{vocab_size}+dim:{dim}+reduction:{REDUCTION}+dtype:{hidden_view.dtype}"
        bwd_kernel_compiled = _bwd_config._bwd_kernel.get(key)
        if bwd_kernel_compiled is None:
            bwd_kernel = blackwell.bwd_partial_dlogits.BwdPartialDlogits(
                reduction=REDUCTION.value, vocab_per_split=vocab_per_split
            )
            bwd_kernel_compiled = cute.compile(bwd_kernel, 0, *kernel_args)  # 0: split_idx
            _bwd_config._bwd_kernel[key] = bwd_kernel_compiled

        d_hidden_2d = d_hidden.view(-1, dim)
        for split_idx in range(num_splits):
            bwd_kernel_compiled(split_idx, *kernel_args)

            # remove padding areas
            # cublas can handle non-contiguous tensors
            # therefore, we do not need to contiguous the tensor
            row_begin = split_idx * vocab_per_split
            row_end = min(row_begin + vocab_per_split, vocab_size)
            valid_d_logits = _d_logits[:, : row_end - row_begin]

            # d_hidden (+)= d_logits_slice @ weight_slice. beta is 0 for the
            # first split so the uninitialised d_hidden is ignored, not read.
            torch.addmm(
                input=d_hidden_2d,
                mat1=valid_d_logits,
                mat2=weight[row_begin:row_end, :],
                beta=0.0 if split_idx == 0 else 1.0,
                alpha=1.0,
                out=d_hidden_2d,
            )
            # d_weight_slice = d_logits_slice^T @ hidden
            torch.matmul(
                valid_d_logits.T,
                hidden_view,
                out=d_weight[row_begin:row_end, :],
            )
    else:
        raise NotImplementedError(f"Unsupported backward method: {_backward}")

    if in_tp_mode:
        # Each rank only saw its own vocabulary shard; sum the partial d_hidden.
        dist.all_reduce(d_hidden, op=dist.ReduceOp.SUM, group=tp_group)
        if sequence_parallel:
            # Hand back only the slice matching this rank's shard of hidden.
            partial_hidden_shape = (
                global_hidden.shape[0] // tp_world_size,
                *global_hidden.shape[1:],
            )
            partial_num_tokens = num_tokens // tp_world_size
            d_hidden = d_hidden.view(-1, d_hidden.shape[-1])[
                tp_rank * partial_num_tokens : (tp_rank + 1) * partial_num_tokens, :
            ]
            d_hidden = d_hidden.view(partial_hidden_shape).clone()

    return d_hidden, d_weight
class Platform:
    """
    Singleton that binds the forward/backward entry points for the current GPU.

    The backend is chosen from the compute capability of the current CUDA
    device when the singleton is first constructed. Only major version 10 is
    supported.
    """

    _instance: typing.Optional["Platform"] = None

    def __new__(cls) -> "Platform":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every `Platform()` call; only the first one does work.
        if getattr(self, "_initialized", False):
            return

        assert torch.cuda.is_available(), "CUDA is not available"
        device = torch.cuda.current_device()
        cc = torch.cuda.get_device_capability(device)

        if cc[0] == 10:
            # Imported here so the CuTe-DSL dependencies are only needed on
            # machines that actually use this backend.
            from transformer_engine.pytorch.cutedsl import linear_cross_entropy_blackwell as gpu_entry

            self.forward_func: typing.Callable[..., typing.Any] = gpu_entry.forward
            self.backward_func: typing.Callable[..., typing.Any] = gpu_entry.backward
        else:
            raise ValueError(f"Unsupported architecture: {cc[0]}")

        self._initialized = True


def _get_platform() -> Platform:
    """
    Return the process-wide ``Platform``, resolving the backend on first use.

    The platform is deliberately NOT created at import time: this module is
    imported by the package ``__init__``, so an eager ``Platform()`` would make
    the whole package fail to import on machines without CUDA or with an
    unsupported GPU.
    """
    return Platform()


def __getattr__(name: str) -> typing.Any:
    """Keep the former module attribute ``_platform`` available, lazily."""
    if name == "_platform":
        return _get_platform()
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


class LinearCrossEntropy(torch.autograd.Function):
    """
    This class implements a custom autograd function for linear and cross entropy,
    whose equivalent logic in PyTorch is:
    ```python
    def torch_entropy(hidden, weight, labels):
        logits = torch.matmul(hidden, weight.T)  # weight is (vocab_size, dim)
        logprobs = torch.nn.functional.cross_entropy(logits, labels)
        return logprobs
    ```
    """

    @staticmethod
    def forward(
        ctx,
        hidden: torch.Tensor,
        weight: torch.Tensor,
        labels: torch.Tensor,
        tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
        reduction: typing.Literal["none", "sum", "mean"] = "mean",
        ignore_index: int = -100,
        sequence_parallel: bool = False,
    ) -> torch.Tensor:
        """
        The forward pass of the Linear Cross Entropy.
        If tp_group is not None, the weight tensor to each TP rank should be
        (global_vocab_size // world_size, dim).
        Note that each of the ranks should get equal shards along the vocab_size dimension.

        Args:
            @param hidden: the input tensor with shape (num_tokens, dim)
            @param weight: the lm_head weight tensor with shape (local_vocab_size, dim)
            @param labels: the labels tensor with shape (num_tokens,)
            @param tp_group: the distributed process group for TP.
            @param reduction: Default to "mean", and can be one of "none", "sum", "mean".
            @param ignore_index: The index to ignore. Default to -100.
            @param sequence_parallel: Whether to use sequence parallel. Default to False.
        Returns:
            @return: logprobs with shape
                - either (num_tokens,) when reduction is "none"
                - or a scalar (0-dim tensor) when reduction is "mean" or "sum"

        tp_group is None ----------------------------------> DP
                B
            A   C
        tp_group is not None & sequence_parallel is False -> TP
                B0  B1
            A   C0  C1
        tp_group is not None & sequence_parallel is True --> SP
                B0  B1
            A0  C0  XX
            A1  XX  C1

        When tp_group is not None, the weight tensor is split along the vocab_size
        dimension: each rank holds a (local_vocab_size, dim) shard, and every rank
        is assumed to hold the same local_vocab_size.

        When sequence_parallel is True, the hidden tensor is split along the
        sequence length dimension: each rank holds a (local_num_tokens, dim) shard,
        and every rank is assumed to hold the same local_num_tokens.

        In the TP forward pass, the hidden tensor and the label tensor shall be
        identical among all TP ranks; it is the user's responsibility to ensure
        this. The operation then produces identical logprobs on all TP ranks.

        In the TP backward pass, the gradient of the logprobs shall be identical
        among all TP ranks; it is the user's responsibility to ensure this. The
        operation then produces distinct gradients for the local weight shard and
        identical gradients for the hidden tensor.

        ```python
        # ------------ forward pass ------------ #
        hidden = tp_group.broadcast(hidden, src=0) # handled by framework
        labels = tp_group.broadcast(labels, src=0) # handled by framework
        logprobs = linear_cross_entropy(...)
        # each rank will get the same logprobs

        # ------------ backward pass ------------ #
        g_logprobs = tp_group.broadcast(g_logprobs, src=0) # handled by framework
        d_hidden, d_weight = torch.autograd.grad(...)
        # each rank will get the same d_hidden,
        # and distinct d_weight for local weight shard
        ```

        In the SP forward pass, the hidden tensor shall be split along the sequence
        length dimension, and the label tensor shall be identical among all TP ranks.
        The operation then produces identical logprobs on all TP ranks.

        In the SP backward pass, the gradient of the logprobs shall be identical
        among all TP ranks. The operation then produces distinct gradients for the
        local hidden tensor and the local weight tensor.
        ```python
        # ------------ forward pass ------------ #
        hidden = global_hidden[tp_rank] # handled by framework
        labels = tp_group.broadcast(labels, src=0) # handled by framework
        logprobs = linear_cross_entropy(...)
        # each rank will get the same logprobs

        # ------------ backward pass ------------ #
        g_logprobs = tp_group.broadcast(g_logprobs, src=0) # handled by framework
        d_hidden, d_weight = torch.autograd.grad(...)
        # each rank will get distinct local d_hidden and d_weight
        ```
        """
        with torch.cuda.nvtx.range("LinearCrossEntropy-forward"):
            logprobs, _maximum, _acc, _num_valid_tokens, tp_rank, tp_world_size, global_hidden = (
                _get_platform().forward_func(
                    hidden, weight, labels, tp_group, reduction, ignore_index, sequence_parallel
                )
            )
            # Tensors needed to recompute the logits gradient in backward.
            ctx.save_for_backward(global_hidden, weight, labels, _maximum, _acc, _num_valid_tokens)
            ctx.tp_group = tp_group
            ctx.ignore_index = ignore_index
            ctx.reduction = reduction
            ctx.tp_rank = tp_rank
            ctx.tp_world_size = tp_world_size
            ctx.sequence_parallel = sequence_parallel

        return logprobs

    @staticmethod
    def backward(
        ctx, dlogprobs: torch.Tensor
    ) -> typing.Tuple[torch.Tensor, torch.Tensor, None, None, None, None, None]:
        """
        The backward pass of the Linear Cross Entropy.
        Args:
            dlogprobs (torch.Tensor): The gradient of the cross entropy, with shape
                - either (num_tokens,) when reduction is "none"
                - or a scalar (0-dim tensor) when reduction is "mean" or "sum"
        Returns:
            dhidden (torch.Tensor): The gradient of the hidden.
            dweight (torch.Tensor): The gradient of the weight.
            The remaining five entries are None: labels, tp_group, reduction,
            ignore_index and sequence_parallel receive no gradient.
        """
        with torch.cuda.nvtx.range("LinearCrossEntropy-backward"):
            (global_hidden, weight, labels, _maximum, _accu, _num_valid_tokens) = ctx.saved_tensors

            d_hidden, d_weight = _get_platform().backward_func(
                dlogprobs,
                global_hidden,
                weight,
                labels,
                _maximum,
                _accu,
                _num_valid_tokens,
                ctx.reduction,
                ctx.ignore_index,
                ctx.tp_group,
                ctx.tp_rank,
                ctx.tp_world_size,
                ctx.sequence_parallel,
            )

        return d_hidden, d_weight, None, None, None, None, None


def linear_cross_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    tp_group: typing.Optional[torch.distributed.ProcessGroup] = None,
    reduction: typing.Literal["none", "sum", "mean"] = "mean",
    ignore_index: int = -100,
    sequence_parallel: bool = False,
) -> torch.Tensor:
    """
    Functional entry point for the fused linear + cross-entropy.

    Forwards all arguments to ``LinearCrossEntropy.apply``; see
    ``LinearCrossEntropy.forward`` for their meaning and for the TP/SP
    requirements.
    """
    _impl = LinearCrossEntropy.apply
    return _impl(hidden, weight, labels, tp_group, reduction, ignore_index, sequence_parallel)


__all__ = ["linear_cross_entropy", "LinearCrossEntropy"]