From 5b3092a0e40654436bec5ea0a0b0f7ad2887b20d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 19 Sep 2025 10:17:36 -0700 Subject: [PATCH 001/141] Changed VERSION to 2.9.0.dev0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 81006d78c..8bfb1cae8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.8.0.dev0 +2.9.0.dev0 From 57b4d7bc0350917cd2122b07f144bfeb4c04eb0b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 22 Sep 2025 12:58:24 -0400 Subject: [PATCH 002/141] [JAX] Remove import jax.extend.ffi (#2193) * remove import jax.extend.ffi Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/activation.py | 7 +------ transformer_engine/jax/cpp_extensions/attention.py | 9 +-------- transformer_engine/jax/cpp_extensions/base.py | 8 +------- transformer_engine/jax/cpp_extensions/normalization.py | 8 +------- transformer_engine/jax/cpp_extensions/quantization.py | 8 +------- transformer_engine/jax/cpp_extensions/softmax.py | 8 +------- 6 files changed, 6 insertions(+), 42 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index cdda20166..d0a4e58fb 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,11 +5,10 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec @@ -37,10 +36,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index df89174b2..625f42049 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -8,11 +8,10 @@ from dataclasses import dataclass, replace from functools import partial, reduce from typing import Optional, Tuple -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes, lax +from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding from jax.experimental.custom_partitioning import SdyShardingRule @@ -49,12 +48,6 @@ ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - - __all__ = [ "FusedAttnHelper", "fused_attn_fwd", diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index c05570566..cc8a07860 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -7,22 +7,16 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial -from packaging import version from jax.extend import core from jax.interpreters import xla, mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch +from jax import ffi -import jax import transformer_engine_jax -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - class BasePrimitive(metaclass=ABCMeta): """ diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 7a978c1b7..351767e36 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -7,11 +7,10 @@ import operator from functools import partial, cache, reduce from typing import Optional, Union -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -38,11 +37,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "layernorm_fwd", diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 1813734b5..895913d0a 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,11 +6,10 @@ from functools import reduce from typing import Tuple, Optional, Union import math -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec @@ -41,11 +40,6 @@ NoScaleTensor, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 43cb11a08..575a2dd3a 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -6,22 +6,16 @@ from functools import partial, reduce import operator import warnings -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.sharding import PartitionSpec, NamedSharding from .base import BasePrimitive, register_primitive from .misc import get_padded_spec, check_valid_batch_dims from ..softmax import SoftmaxType -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "scaled_softmax_fwd", From 5e4e0b2c378d2b1ec2ee65dfa85124e1dd805389 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 22 Sep 2025 12:53:25 -0700 Subject: [PATCH 003/141] [PyTorch] Add sink attention support from cuDNN (#2148) * first draft; debug plan failure Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * debug uid error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak params Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add grad in output Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix prints in test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * address review comments Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix unfused grad; add softmax_type; add sink to bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix padding mask; add swa tests; remove requires_grad for off-by-one Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Apply 1 suggestion(s) to 1 file(s) Co-authored-by: Chen Cui Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix indent Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix non-determinism and shapes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add GQA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add CP A2A; dq/dk mismatches Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP A2A; need cleaner solution Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CP A2A; pending cudnn kernel change Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix world size in unit test; avoid thd format Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix kernel_backend, dtype in unit test; fix head_dim for FP8 Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix thd logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 context Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tweak CP logging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * allow no_mask/padding for SWA(left,0) Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "allow no_mask/padding for SWA(left,0)" This reverts commit 08b4ccc67a08b6882080b06aa715f541bb832aca. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add softmax_type to Jax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add cuDNN version control Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * prettify tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * skip 9.13 for MLA, non 192/128 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename compare_with_error Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * small cleanups and improvements Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix minor CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * force sink/dsink to be float32 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch FE to GH FE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return to GH TE main FE commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE to 1.14.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up before CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * bump up cudnn version Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add backend selection guard for unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add docstring for softmax type enums in C Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: Chen Cui Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../attention/run_attention_with_cp.py | 398 +++++++++------ tests/pytorch/attention/test_attention.py | 273 ++++++---- .../attention/test_attention_with_cp.py | 46 +- tests/pytorch/attention/test_kv_cache.py | 1 - tests/pytorch/utils.py | 42 +- .../common/fused_attn/fused_attn.cpp | 216 ++++---- .../fused_attn_f16_arbitrary_seqlen.cu | 467 ++++++++++-------- .../fused_attn_f16_arbitrary_seqlen.h | 61 +-- .../common/fused_attn/fused_attn_fp8.cu | 2 + transformer_engine/common/fused_attn/utils.h | 12 +- .../include/transformer_engine/fused_attn.h | 110 +++-- .../common/util/pybind_helper.h | 4 + .../jax/csrc/extensions/attention.cpp | 167 ++++--- .../dot_product_attention/backends.py | 82 +-- .../dot_product_attention/context_parallel.py | 130 ++++- .../dot_product_attention.py | 44 +- .../attention/dot_product_attention/utils.py | 55 +++ .../pytorch/attention/multi_head_attention.py | 14 + .../pytorch/cpp_extensions/fused_attn.py | 23 + transformer_engine/pytorch/csrc/extensions.h | 25 +- .../pytorch/csrc/extensions/attention.cpp | 139 +++--- transformer_engine/pytorch/module/base.py | 15 +- transformer_engine/pytorch/transformer.py | 14 + 24 files changed, 1515 insertions(+), 827 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index deda80e53..1a7b4b78d 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 +Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 0ad64204f..7e47e7df8 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -17,88 +17,18 @@ from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling +from utils import ModelConfig, compare_and_assert + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} -def run_dpa_with_cp( - dtype="bf16", - model=None, - qkv_format="bshd", - kernel_backend="FlashAttention", - cp_comm_type="p2p", - fp8_mha=False, +def generate_input_shapes( + qkv_format: str, + config: ModelConfig, + world_size: int, + kernel_backend: str, ): - """Test DotProductAttention module with context parallelism""" - - # args are passed as strings - fp8_mha = fp8_mha == "True" - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - if kernel_backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] - if kernel_backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] - - assert config.attn_mask_type in [ - "causal", - "no_mask", - ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if qkv_format == "thd": - if "causal" in config.attn_mask_type: - config.attn_mask_type = "padding_causal" - else: - config.attn_mask_type = "padding" - - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) - - print(f"[INFO] world_size:{world_size}, rank:{rank}") - - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - - # create flash attn comm group for CP - cp_comm_ranks = range(world_size) - assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert ( - world_size % 2 == 0 - ), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!" - cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - - if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) - - # instantiate core attn module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - ) - core_attn = core_attn.cuda() - - # create flash attn inputs if qkv_format == "bshd": q_input_shape = ( config.batch_size, @@ -191,34 +121,158 @@ def run_dpa_with_cp( cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: - assert False, f"{qkv_format} is an unsupported qkv_format!" + assert False, f"{qkv_format=} is not supported!" + + return ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) + + +def get_tols(config, dtype): + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + atol = 2.5e-2 + rtol = 2.5e-2 + else: + atol = 3.5e-2 + rtol = 3.5e-2 + rmse_tol = 0.01 + elif dtype == "fp16": + atol = 5e-3 + rtol = 5e-3 + rmse_tol = 0.01 + elif dtype == "fp8": + atol = 5e-1 + rtol = 5e-1 + rmse_tol = 0.1 + else: + assert False, f"{dtype=} is not supported!" + + return atol, rtol, rmse_tol + +def run_dpa_with_cp( + dtype="bf16", + model=None, + qkv_format="bshd", + kernel_backend="FlashAttention", + cp_comm_type="p2p", + fp8_mha=False, + log_level=logging.WARNING, +): + """Test DotProductAttention module with context parallelism""" + logging.root.setLevel(log_level) + + # set up environment variables and config + fp8_mha = fp8_mha == "True" + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if kernel_backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + config = model_configs_flash_attn[model] + if kernel_backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + config = model_configs_fused_attn[model] + assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type=} is not supported!" + if qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" + + # set up distributed group + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + torch.cuda.set_device(device) + logging.info(f"[Rank {rank}] Setup: world_size {world_size}") + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + # set up communication group for CP + cp_comm_ranks = range(world_size) + assert rank in cp_comm_ranks + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" + " = 2." + ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + cp_comm_sub_groups = [] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) + if dtype == "fp8": + fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + ).cuda() + if config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() + for x in [q, k, v]: + x.requires_grad = True + dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), - ) + if fp8_mha: + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) - # create flash attention bias if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() else: bias = None - # run core_attn without CP - for x in [q, k, v]: - x.requires_grad = True - + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) else: fp8_context = nullcontext() - with fp8_context: out = core_attn( q, @@ -236,8 +290,30 @@ def run_dpa_with_cp( out.backward(dout_fp8) else: out.backward(dout) + dq, dk, dv = q.grad, k.grad, v.grad + d_softmax_offset = None + if config.softmax_type != "vanilla": + d_softmax_offset = core_attn.softmax_offset.grad - # run core_attn wit CP + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") + + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == "fp8": + core_attn.reset_fp8_meta_tensors() + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() + + # set up inputs q_, k_, v_, dout_, *rest = [ x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) ] @@ -267,8 +343,6 @@ def run_dpa_with_cp( ) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( @@ -276,19 +350,8 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - - if dtype == "fp8": - core_attn.reset_fp8_meta_tensors() - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) - else: - fp8_context = nullcontext() + # run attention with fp8_context: out_ = core_attn( q_, @@ -306,18 +369,23 @@ def run_dpa_with_cp( out_.backward(dout_fp8_) else: out_.backward(dout_) - if fp8_mha: assert isinstance(out, Float8Tensor) assert isinstance(out_, Float8Tensor) out = out.dequantize() out_ = out_.dequantize() - for x in [out_, q_.grad, k_.grad, v_.grad]: - assert torch.all(~torch.isnan(x)) - assert torch.all(~torch.isinf(x)) - - # compare results with and without CP + # get outputs + dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + d_softmax_offset_ = None + if config.softmax_type != "vanilla": + d_softmax_offset_ = core_attn.softmax_offset.grad.clone() + for x in [out_, dq_, dk_, dv_, d_softmax_offset_]: + if x is not None: + assert torch.all(~torch.isnan(x)) + assert torch.all(~torch.isinf(x)) + + ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -373,56 +441,70 @@ def run_dpa_with_cp( ).item() == 0 ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - elif dtype == "fp16": - tols = dict(atol=5e-3, rtol=5e-3) - elif dtype == "fp8": - tols = dict(atol=5e-1, rtol=5e-1) - rmse_tol = 0.1 - else: - assert False, f"{dtype} is an unsupported dtype!" - - def _rmse(a, b): - return torch.sqrt((a - b).square().mean()).item() - - def _error(a, b): - if dtype != "fp8": - torch.testing.assert_close(a, b, **tols) - else: - try: - torch.testing.assert_close(a, b, **tols) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert ( - rmse < rmse_tol * rmse_range - ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - if qkv_format == "bshd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[:, 0], b[:, 0]) - _error(a[:, 1], b[:, 1]) - elif qkv_format == "sbhd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[0], b[0]) - _error(a[1], b[1]) - elif qkv_format == "thd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a, b) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" + atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] + names = ["out", "dq", "dk", "dv", "d_softmax_offset"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i]: + if qkv_format == "bshd": + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + else: + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + # destroy distribution group dist.destroy_process_group() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 56bfa1423..a5c345779 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. import logging -import math import os import sys import pathlib @@ -50,27 +49,35 @@ sys.path.append(str(_current_file.parent.parent)) from utils import ( reset_rng_states, + compare_and_assert, ModelConfig, dtype_tols, get_available_attention_backends, ) -# Only run FP8 tests on H100 +# Check if hardware supports FP8 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() +# Reset RNG seed and states seed = 1234 -# Reset RNG states reset_rng_states() +# Reset FP8 global state manager @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield fp8.FP8GlobalStateManager.reset() +# Define F16 data types to test +param_types = [torch.float16] +if is_bf16_compatible(): + param_types.append(torch.bfloat16) +param_types_lean = [torch.bfloat16] + model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "base_1_0": ModelConfig(8, 128, 16, 64), "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), "base_2_0": ModelConfig(2, 2048, 24, 128), @@ -86,12 +93,6 @@ def reset_global_fp8_state(): } -param_types = [torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher - param_types.append(torch.bfloat16) -param_types_lean = [torch.bfloat16] - - @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -125,12 +126,12 @@ def test_dot_product_attention( config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + # Get backends is_training = True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -141,7 +142,6 @@ def test_dot_product_attention( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -227,6 +227,7 @@ def test_dot_product_attention( is_training, ) + # Compare results logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") @@ -259,23 +260,102 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_softmax = { + # test: ModelConfig(b, sq, hq, dqk) + "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), + "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"), + "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"), + "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "softmax_2_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), + "softmax_2_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"), + "softmax_3_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one" + ), + "softmax_3_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable" + ), + "softmax_4_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal" + ), + "softmax_4_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="off-by-one", + ), + "softmax_4_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="learnable", + ), + "softmax_5_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal" + ), + "softmax_5_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="off-by-one", + ), + "softmax_5_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="learnable", + ), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention( + dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False + ) + + model_configs_mla = { - # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 - "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 + # test: ModelConfig(b, sq, hq, dqk) + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), "mla_2_1": ModelConfig( 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 - ), # cross, 1 + ), "mla_2_2": ModelConfig( 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 - ), # cross, 1 - "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference - "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference + ), + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), } @@ -289,7 +369,7 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), @@ -344,18 +424,16 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), - "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped - "bias_1_5": ModelConfig( - 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" - ), # skipped + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), + "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"), "bias_2_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_1": ModelConfig( 2, 128, @@ -364,10 +442,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_2_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_3": ModelConfig( 2, 2048, @@ -376,13 +454,11 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped - "bias_2_4": ModelConfig( - 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), + "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"), "bias_2_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), "bias_3_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), @@ -400,14 +476,14 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), "bias_3_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_1": ModelConfig( 2, 128, @@ -416,10 +492,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_3": ModelConfig( 2, 2048, @@ -428,10 +504,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_4": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_5": ModelConfig( 2, 2048, @@ -440,7 +516,7 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="alibi", - ), # skipped + ), } @@ -454,7 +530,7 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { - # test: b, h, hg, d, sq, skv, p, + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), @@ -492,7 +568,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "swa_1_1": ModelConfig(2, 2048, 16, 64), "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), @@ -532,7 +608,7 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { - # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type + # test: ModelConfig(b, sq, hq, dqk) "alibi_1_0": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" ), @@ -586,7 +662,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 128, 16, 64), "layout_0_1": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -634,7 +710,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), @@ -726,7 +802,6 @@ def _run_dot_product_attention( is_training: bool, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run DotProductAttention module with one forward pass and one backward pass""" - # Set RNG and environment varables reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" @@ -989,9 +1064,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tp_group=None, layer_number=1, attention_type=config.attn_type, + softmax_type=config.softmax_type, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() + if is_training and config.softmax_type != "vanilla": + block.softmax_offset.requires_grad = True # Run a forward and backward pass if backend in ["FlashAttention", "UnfusedDotProductAttention"]: @@ -1026,12 +1104,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ) if is_training: out.backward(d_out) - + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = block.softmax_offset.grad if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1060,18 +1140,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) + return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) else: - return out_orig, (None, None, None) + return out_orig, (None, None, None, d_softmax_offset) else: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, (None, None, None, d_softmax_offset) model_configs_te_layer = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), "te_1_1": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -1436,6 +1516,7 @@ def _run_transformer_layer( model_configs_fp8_extra_state = { + # test: ModelConfig(b, sq, hq, dqk) "large": ModelConfig(2, 128, 4, 128, num_layers=1), } @@ -1445,7 +1526,8 @@ def _run_transformer_layer( @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): +def test_dpa_fp8_extra_state(model, dtype): + """Test DotProductAttention module in FP8 with checkpointing""" config = model_configs_fp8_extra_state[model] # Test backend availability is_training = True @@ -1459,9 +1541,9 @@ def test_sanity_attention_extra_state(model, dtype): if not fused_attn_supported and not flash_attn_supported: pytest.skip("No attention backend available.") - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( + outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_dpa_fp8_extra_state( dtype, config, mimic_v1_6=True, checkpoint=True ) @@ -1483,7 +1565,8 @@ def test_sanity_attention_extra_state(model, dtype): ) -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): +def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + """Run DotProductAttention module in FP8 with checkpointing""" steps = 10 path = "checkpoint.pt" fp8_enabled = True @@ -1580,7 +1663,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 16, 128), "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), @@ -1600,33 +1683,6 @@ def get_model(dtype, config): qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] -def _rmse(a, b): - return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) - - -def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): - logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) - logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) - try: - if a.dtype != b.dtype: - a = a.to(b.dtype) - torch.testing.assert_close(a, b, atol=atol, rtol=rtol) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert rmse < rmse_tol * rmse_range, ( - name_a - + " vs " - + name_b - + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - ) - - @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") @@ -1638,6 +1694,7 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): + """Test MultiHeadAttention module in FP8""" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] @@ -1691,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -1699,8 +1756,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) - _error( + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -1708,12 +1766,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) if is_training: for i in range(len(param_names[:1])): logging.debug("========== {:^25s} ==========".format(param_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -1721,10 +1780,12 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): + """Run MultiHeadAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1851,6 +1912,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): + """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] # TODO(cyang): think of another way to verify dropout results @@ -1920,7 +1982,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -1928,6 +1990,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) if config.dropout_p != 0.0: # test cuDNN FP8 dropout @@ -1935,7 +1998,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): fused_attn_fwd_fp8 == 1 ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." else: - _error( + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -1943,11 +2006,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) if is_training: for i, _ in enumerate(fused_attn_bwd_f16): logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -1955,11 +2019,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): - + """Run DotProductAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -2092,7 +2157,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_1": ModelConfig(1, 512, 1, 64), "fp8_2": ModelConfig(4, 512, 16, 64), "fp8_3": ModelConfig(1, 2048, 1, 128), @@ -2147,7 +2212,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 rmse_tol = 0.13 - _error( + compare_and_assert( fused_attn_fwd_fp8, unfused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -2155,8 +2220,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) - _error( + compare_and_assert( fused_attn_bwd_fp8, unfused_attn_bwd_f16, "fused_attn_bwd_fp8", @@ -2164,6 +2230,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 7078cb69d..c752d07d8 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -6,6 +6,7 @@ import subprocess import sys import pathlib +import logging import pytest import torch @@ -19,13 +20,15 @@ sys.path.append(str(_current_file.parent.parent)) from utils import ModelConfig, get_available_attention_backends +pytest_logging_level = logging.getLevelName(logging.root.level) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) model_configs_flash_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA @@ -72,6 +75,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") config = model_configs_flash_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": @@ -89,6 +94,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ) if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} + available_backends, *_ = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + ) + flash_attn_supported, *_ = available_backends + if not flash_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( @@ -98,13 +112,14 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ), check=True, ) model_configs_fused_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig( @@ -135,6 +150,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA + "cp_4_0": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" + ), # GQA + "cp_4_1": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), # GQA + "cp_4_2": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), # GQA } @@ -158,6 +182,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("FP8 attention is only supported on sm90+!") config = model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": @@ -191,13 +217,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + if dtype == "fp8" and config.softmax_type != "vanilla": + pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip( + "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" + ) + if config.softmax_type != "vanilla" and qkv_format == "thd": + pytest.skip( + "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + ) + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} available_backends, _, fused_attn_backends = get_available_attention_backends( config, - qkv_dtype=dtypes[dtype], + qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), - window_size=config.window_size, - context_parallel=True, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -212,6 +247,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, fp8_mha=fp8_mha, + log_level=pytest_logging_level, ), check=True, ) diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index 288c5382e..4dc3af411 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -469,7 +469,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=False, is_training=False, fp8=is_fp8, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 38f400f65..9e90f9fda 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -20,6 +20,7 @@ get_attention_backend, AttentionParams, AttentionLogging, + check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend @@ -137,6 +138,31 @@ def reset_rng_states() -> None: torch.cuda.set_rng_state(cuda_rng_state) +def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if not is_fp8: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + return + + try: + if a.dtype != b.dtype: + a = a.to(b.dtype) + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + except Exception as e: + logging.debug(e) + + rmse = torch.sqrt((a - b).square().mean()).item() + logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert rmse < rmse_tol * rmse_range, ( + name_a + + " vs " + + name_b + + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + ) + + class ModelConfig: def __init__( self, @@ -147,12 +173,15 @@ def __init__( max_seqlen_kv: int = None, num_gqa_groups: int = None, head_dim_v: int = None, + softmax_type: str = "vanilla", dropout_p: float = 0.0, attn_mask_type: str = "no_mask", attn_bias_type: str = "no_bias", alibi_type: str = "none", bias_shape: str = "1hss", window_size: Tuple[int, int] = (-1, -1), + context_parallel: bool = False, + cp_comm_type: str = "p2p", total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -171,13 +200,16 @@ def __init__( self.kv_channels = (self.head_dim_qk, self.head_dim_v) self.hidden_size = self.num_heads * self.head_dim_qk self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.softmax_type = softmax_type self.dropout_p = dropout_p self.attn_mask_type = attn_mask_type self.attn_bias_type = attn_bias_type self.alibi_type = alibi_type self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape - self.window_size = window_size + self.window_size = check_set_window_size(self.attn_mask_type, window_size) + self.context_parallel = context_parallel + self.cp_comm_type = cp_comm_type self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -198,9 +230,7 @@ def get_available_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), pad_between_seqs: bool = False, - context_parallel: bool = False, deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, @@ -250,19 +280,21 @@ def test(): head_dim_qk=config.head_dim_qk, head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, - window_size=window_size, + window_size=config.window_size, alibi_slopes_shape=alibi_slopes_shape, core_attention_bias_type=config.attn_bias_type, core_attention_bias_shape=core_attention_bias_shape, core_attention_bias_requires_grad=core_attention_bias_requires_grad, pad_between_seqs=pad_between_seqs, attention_dropout=config.dropout_p, - context_parallel=context_parallel, + context_parallel=config.context_parallel, + cp_comm_type=config.cp_comm_type, deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, inference_params=inference_params, + softmax_type=config.softmax_type, ) ( use_flash_attention, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 795697635..77cd8d235 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -135,9 +135,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -175,7 +176,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // TODO (cyang): add is_training to nvte_get_fused_attn_backend // sm90: fwd d<=256, bwd d=128 only // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || @@ -183,7 +185,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && + !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { @@ -213,7 +215,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { flag_m512 = true; } if ( @@ -363,7 +366,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // check 64-bit ragged offset support (supported_ragged_offset_size) && // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && + // softmax type + // pre-9.13.1: vanilla + // 9.13.1+: vanilla, off-by-one, learnable + (cudnn_runtime_version >= 91301 || + (cudnn_runtime_version < 91301 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -405,14 +414,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, float attn_scale, + float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -421,6 +432,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); const Tensor *input_QKV = convertNVTETensorCheck(QKV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -447,8 +459,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -463,9 +475,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, + input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -487,10 +499,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); @@ -505,6 +518,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con Tensor *input_output_dP = convertNVTETensorCheck(dP); Tensor *output_dQKV = convertNVTETensorCheck(dQKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_QKV->data.shape.size(); @@ -529,8 +543,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -543,19 +557,22 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd_qkvpacked( b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, - input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, + input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, + stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -580,14 +597,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -600,6 +618,7 @@ void nvte_fused_attn_fwd_kvpacked( const Tensor *input_Q = convertNVTETensorCheck(Q); const Tensor *input_KV = convertNVTETensorCheck(KV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -660,8 +679,8 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -677,10 +696,11 @@ void nvte_fused_attn_fwd_kvpacked( fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -702,12 +722,12 @@ void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -723,6 +743,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor *output_dQ = convertNVTETensorCheck(dQ); Tensor *output_dKV = convertNVTETensorCheck(dKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -755,8 +776,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -770,20 +791,23 @@ void nvte_fused_attn_bwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, + input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, + output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, + handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -809,16 +833,17 @@ void nvte_fused_attn_bwd_kvpacked( } // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -832,6 +857,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_K = convertNVTETensorCheck(K); const Tensor *input_V = convertNVTETensorCheck(V); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -886,8 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -903,10 +929,11 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, + window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, + Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -928,14 +955,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -953,6 +981,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dK = convertNVTETensorCheck(dK); Tensor *output_dV = convertNVTETensorCheck(dV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_Q->data.shape.size(); @@ -978,8 +1007,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -993,19 +1022,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_Bias, output_S, output_dQ, output_dK, - output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4e6c3c858..1d6435ad8 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -54,10 +54,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -98,8 +100,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_q = is_ragged_q ? max_t_q : s_q; s_kv = is_ragged_kv ? max_t_kv : s_kv; } - const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { FADescriptor_v1 descriptor{b, h, @@ -122,6 +124,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, true, @@ -138,6 +141,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // O std::shared_ptr, // Stats std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // page_table_k @@ -168,7 +172,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr Q, K, V, attn_scale, softmax_offset; std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr page_table_k, page_table_v; std::shared_ptr offset_q, offset_k, offset_v, offset_o, @@ -302,6 +306,15 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_sink_token(softmax_offset); + } + auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); std::vector o_stride(4); @@ -338,6 +351,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); auto Stats_tuple = std::make_tuple(Stats); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto page_table_tuple = is_paged_kv ? std::make_tuple(page_table_k, page_table_v) @@ -358,17 +373,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, - page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_fprop_cache, descriptor); + auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv, + page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. @@ -473,6 +489,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } + + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -483,14 +504,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, + void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -506,6 +527,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -558,6 +580,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, @@ -579,6 +602,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // dV std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // offset_q @@ -608,7 +633,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_compute_data_type(fe::DataType_t::FLOAT); std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset, + seq_q, seq_kv; std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; @@ -771,6 +797,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_dsink_token(d_softmax_offset); + } + auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); @@ -796,6 +837,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr> // dV key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto offset_qo_tuple = @@ -814,17 +858,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, - offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_bprop_cache, descriptor); + auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset, + d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. @@ -938,6 +982,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -949,8 +998,9 @@ using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -977,6 +1027,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -990,53 +1044,50 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( max_tokens = get_max_tokens(num_tokens); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1050,11 +1101,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, + nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), + workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1074,9 +1125,10 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1122,6 +1174,12 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; @@ -1135,11 +1193,11 @@ void fused_attn_arbitrary_seqlen_bwd_qkvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, - devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, + devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, + devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, + devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1161,12 +1219,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1192,6 +1250,10 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrO = output_O->data.dptr; void *devPtrS = nullptr; @@ -1216,53 +1278,50 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1277,11 +1336,11 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1302,10 +1361,11 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { @@ -1359,6 +1419,12 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1374,9 +1440,10 @@ void fused_attn_arbitrary_seqlen_bwd_kvpacked( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); @@ -1401,12 +1468,13 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1425,6 +1493,10 @@ void fused_attn_arbitrary_seqlen_fwd( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1446,53 +1518,50 @@ void fused_attn_arbitrary_seqlen_fwd( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_bias->data.dptr = nullptr; output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1507,11 +1576,11 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, + devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, + devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1532,13 +1601,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1577,6 +1647,12 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdV = output_dV->data.dptr; void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1592,9 +1668,10 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index e1a20274f..b9658b053 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -21,17 +21,19 @@ namespace transformer_engine { void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, + const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -41,21 +43,22 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, + Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -66,24 +69,26 @@ void fused_attn_arbitrary_seqlen_fwd( size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d7f098376..995dbda7f 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1695,6 +1695,7 @@ void fused_attn_fp8_fwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, true, @@ -2000,6 +2001,7 @@ void fused_attn_fp8_bwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, false, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 678b63691..0a0197423 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -107,6 +107,7 @@ struct FADescriptor_v1 { NVTE_QKV_Layout layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; + NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; bool deterministic; @@ -116,14 +117,15 @@ struct FADescriptor_v1 { bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, - window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, + bwd_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, - rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, - rhs.bwd_tensor_type); + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, + rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, + rhs.fwd_tensor_type, rhs.bwd_tensor_type); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 44f579149..a150978c4 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -124,6 +124,24 @@ enum NVTE_Mask_Type { NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5, }; +/*! \enum NVTE_Softmax_Type + * \brief Attention softmax types as described in + * Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3). + * For a given attention score S = Q*K^T, different softmax types perform different operations on S, + * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + * where alpha is a learnable parameter in shape [H]. + */ +enum NVTE_Softmax_Type { + /*! Vanilla softmax */ + NVTE_VANILLA_SOFTMAX = 0, + /*! Off-by-one softmax */ + NVTE_OFF_BY_ONE_SOFTMAX = 1, + /*! Learnable softmax */ + NVTE_LEARNABLE_SOFTMAX = 2, +}; + /*! \enum NVTE_Fused_Attn_Backend * \brief Fused attention backends */ @@ -178,6 +196,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] qkv_layout The layout of Tensors Q, K, V. * \param[in] bias_type The attention bias type. * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. * \param[in] dropout The dropout probability. * \param[in] num_attn_heads The number of heads in Q. * \param[in] num_gqa_groups The number of heads in K, V. @@ -190,9 +209,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); /*! \brief Compute dot product attention with packed QKV input. * @@ -224,6 +244,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -239,19 +260,19 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked( + const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, + bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -284,6 +305,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * e.g. M, ZInv, rng_state. * \param[out] dQKV The gradient of the QKV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, @@ -293,6 +315,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -302,10 +325,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); @@ -340,6 +364,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] Q The Q tensor, in HD layouts. * \param[in] KV The KV tensor, in 2HD or H2D layouts. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -361,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -368,13 +394,15 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -409,6 +437,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dQ The gradient of the Q tensor. * \param[out] dKV The gradient of the KV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -422,6 +451,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -431,12 +461,12 @@ void nvte_fused_attn_fwd_kvpacked( void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -473,6 +503,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] K The K tensor. * \param[in] V The V tensor. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -494,22 +525,24 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream); + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -549,6 +582,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[out] dK The gradient of the K tensor. * \param[out] dV The gradient of the V tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -562,6 +596,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. @@ -571,14 +606,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 67d21f618..68b7aa8bb 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -36,6 +36,10 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) \ + .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \ + .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \ + .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \ pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 40089dc2d..9277569e1 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -18,10 +18,11 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right) { + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); return backend; } @@ -146,6 +147,9 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); @@ -172,28 +176,30 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + ragged_offset_tensor.data(), dummy_page_table_tensor.data(), + dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, + kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, + query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); } @@ -262,10 +268,15 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; + auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -280,12 +291,12 @@ static void FusedAttnForwardImpl( if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), + o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ -293,12 +304,13 @@ static void FusedAttnForwardImpl( auto q_tensor = TensorWrapper(q, q_shape, dtype); auto kv_tensor = TensorWrapper(k, kv_shape, dtype); nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), + q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -307,12 +319,13 @@ static void FusedAttnForwardImpl( auto k_tensor = TensorWrapper(k, k_shape, dtype); auto v_tensor = TensorWrapper(v, v_shape, dtype); nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -444,6 +457,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 min_num_segments = input_batch * max_segments_per_seq; } + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = @@ -453,37 +469,38 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), - nullptr); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), doutput_tensor.data(), s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, - query_workspace_tensor.data(), nullptr); + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -515,14 +532,17 @@ static void FusedAttnBackwardImpl( /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); @@ -540,10 +560,11 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ -562,10 +583,11 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), + q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; @@ -586,11 +608,12 @@ static void FusedAttnBackwardImpl( s_tensor.data(), // not used for F16 s_tensor.data(), // not used for F16 &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, + kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, + workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index afa1bae63..4a60bd9fe 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -13,6 +13,7 @@ from packaging.version import Version as PkgVersion import torch +import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( SplitAlongDim, @@ -142,6 +143,7 @@ def __init__( attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -149,6 +151,7 @@ def __init__( self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number + self.softmax_type = softmax_type def mask_func(x, y): return ( @@ -185,6 +188,7 @@ def forward( core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -326,7 +330,21 @@ def forward( dtype=query_layer.dtype ) - # attention scores and attention mask [b, np, sq, sk] + # add attention sink to the last column: [b, np, sq, sk+1] + if self.softmax_type != "vanilla": + matmul_result = torch.cat( + [ + matmul_result, + softmax_offset.to(dtype=matmul_result.dtype).expand( + matmul_result.size(0), -1, matmul_result.size(2), -1 + ), + ], + dim=-1, + ) + attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False) + attn_mask_type = "arbitrary" + + # attention scores and attention mask softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( matmul_result, attention_mask, attn_mask_type, softmax_scale @@ -337,6 +355,10 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) + # remove attention sink: [b, np, sq, sk] + if self.softmax_type != "vanilla": + attention_probs = attention_probs[..., :-1] + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. with self.attention_dropout_ctx(): @@ -917,6 +939,7 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, fused_attention_backend, @@ -925,6 +948,7 @@ def forward( fp8_meta, quantizers, deterministic, + softmax_offset, ): # pylint: disable=missing-function-docstring # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype @@ -997,8 +1021,10 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, ) if is_output_fp8: out_ret = out_fp8 @@ -1059,8 +1085,10 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, ) out_save = out_ret fp8_tensors = (None, None, None, None) @@ -1114,6 +1142,7 @@ def forward( ctx.qkv_layout = qkv_layout ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.softmax_type = softmax_type ctx.window_size = window_size ctx.fused_attention_backend = ( fused_attention_backend if ctx.fp8 else FusedAttnBackend["F16_arbitrary_seqlen"] @@ -1224,6 +1253,7 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, ) @@ -1287,42 +1317,17 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, ) - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - None, - None, - dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) + d_bias = None + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + d_softmax_offset = None + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] return ( None, None, @@ -1336,7 +1341,8 @@ def backward(ctx, d_out): dq, dk, dv, - rest[0], + d_bias, + None, None, None, None, @@ -1351,6 +1357,7 @@ def backward(ctx, d_out): None, None, None, + d_softmax_offset, ) @@ -1390,6 +1397,7 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -1402,6 +1410,7 @@ def __init__( ) == "1" and get_device_compute_capability() == (9, 0) self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.softmax_type = softmax_type def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1453,6 +1462,7 @@ def forward( quantizers=None, pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -1603,6 +1613,8 @@ def forward( fp8_meta=fp8_meta, quantizers=quantizers, pad_between_seqs=pad_between_seqs, + softmax_type=self.softmax_type, + softmax_offset=softmax_offset, ) else: with self.attention_dropout_ctx(): @@ -1626,6 +1638,7 @@ def forward( qkv_layout, core_attention_bias_type, attn_mask_type, + self.softmax_type, window_size, None, # rng_gen fused_attention_backend, @@ -1634,6 +1647,7 @@ def forward( fp8_meta, quantizers, self.deterministic, + softmax_offset, ) # ...hd -> ...(hd) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 09384217c..2e4b6b617 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -46,6 +46,7 @@ _cu_seqlens_info_with_cp_cache = {} _seq_chunk_ids_cache_for_reordering_before_attn = {} _seq_chunk_ids_cache_for_reordering_after_attn = {} +_softmax_offset_chunk_ids_cache = {} def flash_attn_p2p_communicate( @@ -318,6 +319,55 @@ def flash_attn_a2a_communicate( return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +def flash_attn_a2a_communicate_softmax_offset( + tensor: torch.Tensor, + h_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Split/AllGather communication for softmax offset.""" + if tensor is None: + return None + + global _softmax_offset_chunk_ids_cache + device = tensor.device + if (cp_size, device) not in _softmax_offset_chunk_ids_cache: + chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device) + _softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids + else: + chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)] + + if before_attn: + # softmax_offset: split round-robin to CP ranks + # [1, h, 1, 1] -> [1, cp, h//cp, 1, 1] + shape = tensor.shape + tensor = tensor.view( + *shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :] + ) + rank = get_distributed_rank(cp_group) + output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank]) + output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :]) + else: + # d_softmax_offset: all-gather from all ranks to all ranks + # [1, h//cp, 1, 1] -> [1, h, 1, 1] + inp = tensor.view(-1) + output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device) + with torch.cuda.stream(cp_stream): + torch.distributed.all_gather_into_tensor( + output, + inp, + group=cp_group, + async_op=False, + ) + torch.cuda.current_stream().wait_stream(cp_stream) + output = output.view( + *tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :] + ) + return output + + def _get_cu_seqlens_info_with_cp( batch_size: int, max_seqlen: int, @@ -1854,7 +1904,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], @@ -2014,7 +2064,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv // 2, cu_seqlens_q_per_step[cp_size - i - 1], @@ -2171,7 +2221,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q // 2, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], @@ -2289,7 +2339,7 @@ def backward(ctx, dout): ) fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( + dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q_per_step[cp_size - i - 1], @@ -3122,7 +3172,7 @@ def backward(ctx, dout): dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] - dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, @@ -3283,6 +3333,8 @@ def forward( cp_stream, quantizers, use_flash_attn_3, + softmax_type, + softmax_offset, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3391,6 +3443,10 @@ def forward( q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) + if softmax_type != "vanilla": + softmax_offset = flash_attn_a2a_communicate_softmax_offset( + softmax_offset, 1, cp_size, cp_group, cp_stream, True + ) if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): q_f16, k_f16, v_f16 = q, k, v @@ -3430,6 +3486,8 @@ def forward( cu_seqlens_kv_padded=cu_seqlens_kv_padded, window_size=window_size, **fp8_meta_kwargs, + softmax_type=softmax_type, + softmax_offset=softmax_offset, ) if fp8: out = out._data @@ -3532,6 +3590,7 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.softmax_type = softmax_type ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer @@ -3695,7 +3754,7 @@ def backward(ctx, dout): dout_part, fake_dtype=dout_dtype, internal=True ) - dq, dk, dv, _ = fused_attn_bwd( + dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -3719,6 +3778,7 @@ def backward(ctx, dout): window_size=ctx.window_size, deterministic=ctx.deterministic, **fp8_meta_kwargs, + softmax_type=ctx.softmax_type, ) if ctx.fp8: dq = dq._data @@ -3763,6 +3823,17 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + d_bias = None + d_softmax_offset = None + if ctx.use_fused_attention: + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] + d_softmax_offset = flash_attn_a2a_communicate_softmax_offset( + d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + if ctx.fp8: dq = ctx.dQKV_quantizer.create_tensor_from_data( dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 @@ -3793,6 +3864,7 @@ def backward(ctx, dout): None, None, None, + d_bias, None, None, None, @@ -3803,6 +3875,7 @@ def backward(ctx, dout): None, None, None, + d_softmax_offset, ) @@ -3835,6 +3908,8 @@ def attn_forward_func_with_cp( quantizers=None, pad_between_seqs=False, use_flash_attn_3=False, + softmax_type="vanilla", + softmax_offset=None, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3911,23 +3986,23 @@ def attn_forward_func_with_cp( else: assert isinstance( cp_group, dist_group_type - ), f"Unsupported process group for CP communication type {cp_comm_type}!" + ), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!" assert qkv_format in [ "bshd", "sbhd", "thd", - ], f"QKV format of {qkv_format} is not supported with context parallelism!" + ], f"Context parallelism does not support {qkv_format=}!" assert ( qkv_format != "sbhd" or use_fused_attention - ), "FlashAttention does not support sbhd format!" + ), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!" assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( - """Attention bias is only supported with FusedAttention and "causal" """ - """or "no_mask" mask types!""" + "Context parallelism only supports attention bias with FusedAttention backend and" + " non-padding mask types!" ) assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" + ), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) @@ -3935,13 +4010,28 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "The context parallel running configs cannot support sliding window attetnion!" + ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "The context parallel running configs cannot support MLA!" + ], "Context parallelism does not support MLA with {cp_comm_type=}!" + + if fp8 and fp8_meta is not None: + if fp8_meta["recipe"].fp8_dpa: + assert ( + softmax_type == "vanilla" + ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + assert ( + softmax_type == "vanilla" or use_fused_attention + ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + assert ( + softmax_type == "vanilla" or cp_comm_type == "a2a" + ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + assert ( + softmax_type == "vanilla" or qkv_format != "thd" + ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" args = [ is_training, @@ -3982,7 +4072,17 @@ def attn_forward_func_with_cp( args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] + args += [ + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + quantizers, + use_flash_attn_3, + softmax_type, + softmax_offset, + ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b35b87a83..f72cd6926 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -11,6 +11,7 @@ import logging import torch +from torch.nn.parameter import Parameter import transformer_engine_torch as tex from transformer_engine.pytorch.utils import get_cudnn_version @@ -168,6 +169,17 @@ class DotProductAttention(TransformerEngineBaseModule): softmax_scale: Optional[float], default = `None` softmax scale for the attention scores. If `None`, defaults to `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -223,6 +235,7 @@ def __init__( cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -307,6 +320,20 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.softmax_type = softmax_type + if self.softmax_type == "vanilla": + self.softmax_offset = None + if self.softmax_type == "off-by-one": + self.softmax_offset = torch.zeros( + self.num_attention_heads // self.tp_size, device="cuda" + ) + if self.softmax_type == "learnable": + self.register_parameter( + "softmax_offset", + Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")), + get_rng_state_tracker=get_rng_state_tracker, + ) + attn_kwargs = { "attention_dropout": attention_dropout, "attention_dropout_ctx": attention_dropout_ctx, @@ -328,6 +355,7 @@ def __init__( layer_number=layer_number, deterministic=self.deterministic, **attn_kwargs, + softmax_type=self.softmax_type, ) self.unfused_attention = UnfusedDotProductAttention( @@ -335,6 +363,7 @@ def __init__( attention_type=attention_type, **attn_kwargs, layer_number=layer_number, + softmax_type=self.softmax_type, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -634,6 +663,7 @@ def forward( query_layer, num_gemms=3, allow_non_contiguous=True, + allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): @@ -922,6 +952,7 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + # check if there is padding between sequences when qkv_format='thd' if pad_between_seqs is None: if qkv_format == "thd": pad_between_seqs = ( @@ -957,11 +988,13 @@ def forward( pad_between_seqs=pad_between_seqs, attention_dropout=self.attention_dropout, context_parallel=context_parallel, + cp_comm_type=self.cp_comm_type, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, inference_params=inference_params, + softmax_type=self.softmax_type, ) global _attention_backends if is_in_onnx_export_mode(): @@ -1022,6 +1055,12 @@ def forward( ) # run attention + softmax_offset = ( + self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) + if self.softmax_offset is not None + else None + ) + if use_flash_attention: if core_attention_bias_type == "alibi": alibi_slopes, _ = dpa_utils.get_alibi( @@ -1071,7 +1110,6 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) - # checkpoint_core_attention=False if checkpoint_core_attention: return self._checkpointed_attention_forward( self.fused_attention, @@ -1101,6 +1139,7 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, + softmax_offset=softmax_offset, ) return self.fused_attention( query_layer, @@ -1129,6 +1168,7 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, + softmax_offset=softmax_offset, ) from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled @@ -1157,6 +1197,7 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, ) return self.unfused_attention( _alibi_cache, @@ -1173,5 +1214,6 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, ) return None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9b2b9a1ac..72c595e3f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -24,6 +24,7 @@ QKVLayout, AttnBiasType, AttnMaskType, + SoftmaxType, FusedAttnBackend, META_QKV, META_DQKV, @@ -206,6 +207,8 @@ class AttentionParams: Attention dropout. context_parallel: bool, default = `False` Whether context parallelism is used or not. + cp_comm_type: str, default = "p2p" + The communication type of context parallelism. deterministic: bool, default = `False` Whether to run `DotProductAttention` with determinism or not. is_training: bool, default = `True` @@ -216,6 +219,8 @@ class AttentionParams: The FP8 metadata tensor of `DotProductAttention`. inference_params: Optional[InferenceParams], default = `None` Inference-related parameters. See InferenceParams for details. + softmax_type: str, default = "vanilla" + The type of softmax operation. See DotProductAttention for details. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -237,11 +242,13 @@ class AttentionParams: pad_between_seqs: bool = False attention_dropout: float = 0.0 context_parallel: bool = False + cp_comm_type: str = "p2p" deterministic: bool = False is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None + softmax_type: str = "vanilla" def __eq__(self, other): """ @@ -308,11 +315,13 @@ def get_attention_backend( pad_between_seqs = attention_params.pad_between_seqs attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel + cp_comm_type = attention_params.cp_comm_type deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params + softmax_type = attention_params.softmax_type # Run config logger = logging.getLogger("DotProductAttention") @@ -565,6 +574,51 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FlashAttention 3 for dropout") use_flash_attention_3 = False + # Filter: Softmax type + # context_parallel | softmax_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # no | vanilla | All + # no | off-by-one | FusedAttention, UnfusedDotProductAttention + # no | learnable | FusedAttention, UnfusedDotProductAttention + # yes | vanilla | FusedAttention, FlashAttention + # yes | off-by-one | FusedAttention + # yes | learnable | FusedAttention + if softmax_type != "vanilla": + logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) + use_flash_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type + ) + use_unfused_attention = False + if qkv_format == "thd": + logger.debug( + "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type + ) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", + softmax_type, + ) + use_unfused_attention = False + if context_parallel: + logger.debug( + "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" + " = %s", + softmax_type, + ) + use_unfused_attention = False + if cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False + # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends # ---------------------------------------------------------------------------------------------------- @@ -806,6 +860,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], attention_dropout, num_heads, num_gqa_groups, diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 5fd16bf1a..790d78c75 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -135,6 +135,17 @@ class MultiheadAttention(torch.nn.Module): For that, please use `get_qkv_layout` to gain the layout information. name: str, default = `None` name of the module, currently used for debugging purposes. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -245,6 +256,7 @@ def __init__( qk_norm_before_rope: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -262,6 +274,7 @@ def __init__( self.return_bias = return_bias self.cp_size = 1 self.cp_rank = 0 + self.softmax_type = softmax_type kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) @@ -416,6 +429,7 @@ def __init__( tp_group=tp_group, layer_number=self.layer_number, attention_type=self.attention_type, + softmax_type=self.softmax_type, ) # Linear diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b9810bf86..df2f5d1ca 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -12,6 +12,7 @@ NVTE_QKV_Format, NVTE_Bias_Type, NVTE_Mask_Type, + NVTE_Softmax_Type, NVTE_Fused_Attn_Backend, ) from ..tensor.quantized_tensor import Quantizer @@ -86,6 +87,12 @@ "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, } +SoftmaxType = { + "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, + "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, +} + FusedAttnBackend = { "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, @@ -131,8 +138,10 @@ def fused_attn_fwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, + softmax_offset: torch.Tensor = None, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -197,6 +206,8 @@ def fused_attn_fwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -205,6 +216,9 @@ def fused_attn_fwd( rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + softmax_offset: torch.Tensor, default = None + softmax offset tensor in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. Returns ---------- @@ -286,6 +300,7 @@ def fused_attn_fwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, cu_seqlens_q, cu_seqlens_kv, @@ -300,6 +315,7 @@ def fused_attn_fwd( s_quantizer, o_quantizer, attn_bias, + softmax_offset, rng_gen, rng_elts_per_thread, ) @@ -333,6 +349,7 @@ def fused_attn_bwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), deterministic: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: @@ -398,6 +415,8 @@ def fused_attn_bwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -417,6 +436,9 @@ def fused_attn_bwd( d_bias: torch.Tensor, optional gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias + d_softmax_offset: torch.Tensor, optional + gradient tensor of softmax offset in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. """ if attn_scale is None: d = q.size(-1) @@ -454,6 +476,7 @@ def fused_attn_bwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, deterministic, cu_seqlens_q, diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4cb05725b..4edc6d81e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -73,28 +73,31 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread); + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d835a5c9..8179727e5 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -58,13 +58,14 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); return fused_attention_backend; } @@ -72,14 +73,15 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread) { + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread) { TensorWrapper te_Q, te_K, te_V, te_O, te_S; auto none = py::none(); @@ -181,6 +183,16 @@ std::vector fused_attn_fwd( DType::kInt32, nullptr, nullptr, nullptr); } + // softmax offset + TensorWrapper te_SoftmaxOffset; + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); + std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; + te_SoftmaxOffset = + makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); @@ -199,11 +211,11 @@ std::vector fused_attn_fwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -215,51 +227,52 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; output_tensors.push_back(o_python); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = - (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) - : rng_state; - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } + auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); + }; + // allocate memory for nvte_aux_tensor_pack.tensors + // f16_max512 : S [b, h, sq, skv] + // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + size_t i = 0; + at::Tensor output_tensor; + // intermediate softmax tensor, S or M + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + // fp8 has an additional softmax stats tensor, ZInv + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + } + // rng_state + if (i < nvte_aux_tensor_pack.size) { + set_tensor_param(i++, rng_state); + } + // bias (optional) + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + set_tensor_param(i++, Bias.value()); + } + // softmax_offset (optional) + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + set_tensor_param(i++, SoftmaxOffset.value()); } // execute the kernel NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -274,9 +287,10 @@ std::vector fused_attn_fwd( std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, @@ -499,6 +513,15 @@ std::vector fused_attn_bwd( } } + // create dSoftmaxOffset in the same shape as SoftmaxOffset + at::Tensor dSoftmaxOffset; + TensorWrapper te_dSoftmaxOffset; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA); + dSoftmaxOffset = torch::empty({1, static_cast(h_q), 1, 1}, options); + te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset); + } + // create workspace TensorWrapper workspace; @@ -507,10 +530,10 @@ std::vector fused_attn_bwd( nvte_fused_attn_bwd( te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -523,16 +546,16 @@ std::vector fused_attn_bwd( nvte_fused_attn_bwd( te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, + attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], + window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {py_dQ, py_dK, py_dV, py::cast(dBias)}; + return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)}; } at::Tensor fa_prepare_fwd(at::Tensor qkvi) { diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0f2e3c4de..70366dabe 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -966,12 +966,13 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: return dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) + if not self.allow_different_data_and_param_types: + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. " + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.activation_dtype = dtype def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -1060,6 +1061,7 @@ def prepare_forward( inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. The context manager is needed because there isn't a way for a module to know @@ -1067,6 +1069,7 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ + self.allow_different_data_and_param_types = allow_different_data_and_param_types self.forwarded_at_least_once = True # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 89e43f845..8a032b2f5 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -191,6 +191,17 @@ class TransformerLayer(torch.nn.Module): and `DotProductAttention` modules. name: str, default = `None` name of the module, currently used for debugging purposes. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -306,6 +317,7 @@ def __init__( qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, qk_norm_before_rope: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -362,6 +374,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.attn_input_format = attn_input_format + self.softmax_type = softmax_type self.name = name @@ -397,6 +410,7 @@ def __init__( "qkv_format": self.attn_input_format, "seq_length": seq_length, "micro_batch_size": micro_batch_size, + "softmax_type": self.softmax_type, } self.self_attention = MultiheadAttention( From 2db20a6f8218ee9c04044b5596a71ae4154d68d3 Mon Sep 17 00:00:00 2001 From: shengfangd Date: Tue, 23 Sep 2025 09:00:34 +0800 Subject: [PATCH 004/141] [QA] Add pytest xml report for all tests in qa folder that use pytest (#2169) * Add pytest xml report for debug unittest and onnx unittest, and remove the duplicated test line in qa/L0_pytorch_debug_unittest/test.sh --------- Signed-off-by: erindai --- qa/L0_pytorch_debug_unittest/test.sh | 19 ++++++++++--------- qa/L1_pytorch_distributed_unittest/test.sh | 4 ++-- qa/L1_pytorch_onnx_unittest/test.sh | 4 +++- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index b4bf0a024..7f19dda67 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -7,6 +7,8 @@ : ${TE_PATH:=/opt/transformerengine} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" # Config with the dummy feature which prevents nvinspect from being disabled. # Nvinspect will be disabled if no feature is active. @@ -20,17 +22,16 @@ pip uninstall -y nvdlfw-inspect pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 exit $FAIL diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7f061d222..19889946a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -47,9 +47,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 1486d5097..720aa79e2 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -7,5 +7,7 @@ pip3 install onnxruntime==1.20.1 pip3 install onnxruntime_extensions==0.13.0 : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py From a92a0ad294a750e9c3d26dc9677746daa94da8ee Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Tue, 23 Sep 2025 11:15:06 -0400 Subject: [PATCH 005/141] [JAX] Local-Amax for Current-Scaling (#2183) * Adding Amax Primitive and related args. Signed-off-by: Ming Huang * Enable local-amax for current-scaling and optionally run AR aross FSDP/TP/SP. Signed-off-by: Ming Huang * Adding doc for Amax Primitive. Signed-off-by: Ming Huang * Fix the function name conflict. Signed-off-by: Ming Huang * Modification as feedback suggested. Signed-off-by: Ming Huang * Fix errors from lint. Signed-off-by: Ming Huang * Fix the wrong amax-scope in the bwd. Signed-off-by: Ming Huang * Added more description for amax-scope Signed-off-by: Ming Huang * Fix the wrong attribute name. Signed-off-by: Ming Huang * Keep dim for AmaxCalcuation. Signed-off-by: Ming Huang * Remove keepDim and add shardy_rule Signed-off-by: Ming Huang * Fix shardy_rule Signed-off-by: Ming Huang * Remove extra-collective bytes from ref_coll_count due to local amax. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Signed-off-by: Ming-Xu Huang Co-authored-by: Phuong Nguyen --- tests/jax/test_distributed_layernorm.py | 2 - .../jax/cpp_extensions/activation.py | 12 +- transformer_engine/jax/cpp_extensions/base.py | 17 ++- .../jax/cpp_extensions/normalization.py | 41 ++++- .../jax/cpp_extensions/quantization.py | 142 +++++++++++++++++- transformer_engine/jax/dense.py | 14 +- transformer_engine/jax/layernorm_mlp.py | 8 +- 7 files changed, 213 insertions(+), 23 deletions(-) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index a777e2f43..f3296277c 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -76,8 +76,6 @@ def generate_collectives_count_ref( all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize ) other_bytes = 0 - if fp8_recipe == recipe.Float8CurrentScaling(): - allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction return generate_collectives_count( allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d0a4e58fb..9499b1624 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -26,7 +26,7 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_dbias, _quantize_dbias_impl +from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -979,6 +979,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -987,6 +988,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: If quantizer is None: @@ -1044,7 +1046,13 @@ def act_lu( activation_type=activation_type, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, + ) return out if isinstance(quantizer, DelayedScaleQuantizer): diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index cc8a07860..96b73909e 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -173,7 +173,7 @@ def shardy_sharding_rule(*args): _primitive_registry = {} -def register_primitive(cls): +def register_primitive(cls, outer_only=False): """ Register a JAX primitive and add it to the internal registry. """ @@ -186,13 +186,14 @@ def register_primitive(cls): def name_of_wrapper_p(): return cls.name + "_wrapper" - inner_p = core.Primitive(cls.name) - dispatch.prim_requires_devices_during_lowering.add(inner_p) - inner_p.multiple_results = cls.multiple_results - inner_p.def_impl(partial(xla.apply_primitive, inner_p)) - inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="cuda") - cls.inner_primitive = inner_p + if not outer_only: + inner_p = core.Primitive(cls.name) + dispatch.prim_requires_devices_during_lowering.add(inner_p) + inner_p.multiple_results = cls.multiple_results + inner_p.def_impl(partial(xla.apply_primitive, inner_p)) + inner_p.def_abstract_eval(cls.abstract) + mlir.register_lowering(inner_p, cls.lowering, platform="cuda") + cls.inner_primitive = inner_p outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 351767e36..d265be398 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -27,7 +27,7 @@ NamedSharding, get_cudnn_version, ) -from .quantization import _quantize_dbias_impl +from .quantization import _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -880,6 +880,7 @@ def layernorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: """Layer normalization forward pass with optional quantization. @@ -893,6 +894,7 @@ def layernorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -952,7 +954,13 @@ def layernorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, + ) return out, mu, rsigma is_2x2x = quantizer.is_2x2x() @@ -1082,6 +1090,7 @@ def rmsnorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: """Root mean square normalization forward pass with optional quantization. @@ -1093,6 +1102,7 @@ def rmsnorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -1153,7 +1163,11 @@ def rmsnorm_fwd( quantizer=None, ) out, _ = _quantize_dbias_impl( - out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + out.data, + is_dbias=False, + quantizer=quantizer, + dq_dtype=x.dtype, + amax_scope=amax_scope, ) return out, rsigma @@ -1278,6 +1292,7 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.LOCAL, ): """Common wrapper for normalization forward pass. @@ -1294,6 +1309,7 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A tuple containing: @@ -1311,12 +1327,27 @@ def normalization_fwd( zero_centered_gamma is not supported if norm_type is 'rmsnorm'. """ if norm_type == "layernorm": - output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) + output, mu, rsigma = layernorm_fwd( + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + ) elif norm_type == "rmsnorm": assert ( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" - output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer) + output, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + ) mu = None else: raise ValueError(f"{norm_type=} is not supported.") diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 895913d0a..98b9b7e78 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,6 +6,8 @@ from functools import reduce from typing import Tuple, Optional, Union import math +from enum import Enum + import jax import jax.numpy as jnp @@ -26,7 +28,12 @@ get_min_device_compute_capability, NamedSharding, ) -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp, + global_mesh_resource, + lax_paral_op, +) from ..quantize import ( ScaledTensor2x, ScaledTensor, @@ -526,6 +533,126 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive): """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" +class AmaxScope(Enum): + """ + Amax Scope Enum + """ + + LOCAL = 1 + TPSP = 2 + FSDP = 3 + + +class AmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning + """ + + name = "jax_local_amax" + multiple_results = False + impl_static_args = (1,) # amax_scope + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + ): + """ + amax calcuation abstract + """ + del amax_scope + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + return out_aval + + @staticmethod + def impl( + x, + amax_scope, + ): + """ + amax calcuation implementation + """ + del amax_scope + amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) + return amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del (amax_scope, arg_infos, result_infos) # Unused. + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding + + @staticmethod + def partition( + amax_scope, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + + def sharded_impl(x): + amax = AmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + ) + if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP + gmesh = global_mesh_resource() + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh) + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + + if amax_scope is AmaxScope.FSDP: # Run AR across FSDP + gmesh = global_mesh_resource() + amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + + return amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, amax_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(amax_scope, mesh, value_types, result_types): + """ + amax calcuation shardy_sharding_rule + """ + del amax_scope, mesh, result_types + prefix = "AmaxCal" + input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_spec = (f"{prefix}_amax",) + return SdyShardingRule((input_spec,), (output_spec,)) + + +register_primitive(AmaxCalculationPrimitive, outer_only=True) + + def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): @@ -572,6 +699,7 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -628,7 +756,10 @@ def _quantize_dbias_impl( # until the tensor is dequantized (e.g. in the GEMM). amax = x.amax if amax is None: - amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) + amax = AmaxCalculationPrimitive.outer_primitive.bind( + x.data, + amax_scope=amax_scope, + ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale @@ -700,6 +831,7 @@ def quantize( x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -710,6 +842,7 @@ def quantize( flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. is None. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A ScaledTensor containing the quantized input tensor. @@ -718,6 +851,7 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + amax_scope=amax_scope, ) return out @@ -727,6 +861,7 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -737,6 +872,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + Returns: A tuple containing: @@ -750,6 +887,7 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + amax_scope=amax_scope, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8087159a3..dd7f5e0e8 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -15,6 +15,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .quantize import ( ScaledTensorFactory, ScalingMode, @@ -64,6 +65,7 @@ def dense( input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, quantizer_set: QuantizerSet = noop_quantizer_set, + using_global_amax_of_x: bool = False, ): """Perform dense layer transformation with optional quantization. @@ -77,6 +79,7 @@ def dense( bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract quantizer_set: QuantizerSet which contains quantizers for different tensor types + using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. Returns: Transformed output tensor @@ -93,6 +96,7 @@ def dense( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ) return output @@ -103,6 +107,7 @@ def dense( 3, 4, 5, + 7, ), ) def _dense( @@ -113,6 +118,7 @@ def _dense( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ): """Internal implementation of dense layer transformation with custom VJP. @@ -127,6 +133,7 @@ def _dense( input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix quantizer_set: QuantizerSet which contains quantizers for different tensor types + using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. Returns: Transformed output tensor @@ -139,6 +146,7 @@ def _dense( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ) return output @@ -151,6 +159,7 @@ def _dense_fwd_rule( input_axes, kernel_axes, quantizer_set, + using_global_amax_of_x, ): """Forward pass rule for dense layer transformation. @@ -175,6 +184,7 @@ def _dense_fwd_rule( x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -182,6 +192,7 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -212,7 +223,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad + contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. @@ -238,6 +249,7 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP, ) # GEMM NT diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801a..e3eaa53e1 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,6 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, @@ -272,13 +273,12 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + amax_scope=AmaxScope.TPSP, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, - flatten_axis=-2, - quantizer=ffn1_quantizer_set.kernel, + kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP ) # NN GEMM @@ -317,6 +317,7 @@ def _layernorm_mlp_fwd_rule( casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) # NN GEMM @@ -417,6 +418,7 @@ def _layernorm_mlp_bwd_rule( grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From 3f875fb57fcf2872d238f8c7cb199b171c424536 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 23 Sep 2025 15:10:46 -0400 Subject: [PATCH 006/141] [JAX] Restore Shardy Rule with CompoundFactor (#2167) * Rework shardy rules * WAR for compound factor=1 Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 34 +++--- transformer_engine/jax/cpp_extensions/gemm.py | 11 +- .../jax/cpp_extensions/normalization.py | 5 +- .../jax/cpp_extensions/quantization.py | 5 +- .../jax/quantize/scaling_modes.py | 106 ++++++++++-------- 5 files changed, 90 insertions(+), 71 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 9499b1624..a8c14a608 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -410,27 +410,28 @@ def shardy_sharding_rule( result_types, ): del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) + prefix = "ActLu_" + input_shape = value_types[0].shape + output_shape = input_shape[:-2] + input_shape[-1:] + # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 + output_shape, unique_var=prefix + "x", flatten_axis=-1 ) - x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) - out = (*x_axes[:-2], x_axes[-1]) - scale_inv = scale_rules.rowwise_rule + x_axes = scale_rules.input_spec + # Correct input spec with act dim + x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] + out = scale_rules.input_spec colwise_out = (prefix + "out_colwise",) colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: colwise_scale_inv = scale_rules.colwise_rule if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple( - multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) - ) + colwise_out = multidim_transpose(out, transpose_axis=-1) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule - # amax is always a unit tensor. amax = (prefix + "amax",) return SdyShardingRule( @@ -438,7 +439,8 @@ def shardy_sharding_rule( x_axes, ("…1",), ), - (out, colwise_out, scale_inv, colwise_scale_inv, amax), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), + **scale_rules.factor_sizes, ) @@ -883,26 +885,30 @@ def shardy_sharding_rule( result_types, ): del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" + prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 ) x_axes = scale_rules.input_spec dz_axes = (*x_axes[:-2], x_axes[-1]) out = x_axes + colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "scale_inv_colwise",) if is_2x: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) else: colwise_out = out + colwise_scale_inv = scale_rules.colwise_rule dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) return SdyShardingRule( (dz_axes, x_axes, ("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 2acc3fb68..118000be7 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -712,7 +712,7 @@ def shardy_sharding_rule( del out_dtype, grad, use_split_accumulator del mesh, result_types - prefix = "GemmPrimitive_" + prefix = "Gemm_" warnings.warn( "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," @@ -746,13 +746,8 @@ def _generate_operand_rules(name, ndim, cdims): lhs_scale_specs = ("…1",) rhs_scale_specs = ("…2",) if scaling_mode.is_1d_block_scaling(): - # Shardy rules for MXFP8 scales cannot be related to the operands because of the - # global-unpadding and local-padding workflow. This can potentially insert expensive - # re-shards in the partition call later if the scales are not already sharded correctly. - lhs_scale_specs, rhs_scale_specs = map( - lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), - (lhs_specs, rhs_specs), - ) + lhs_scale_specs = lhs_specs + rhs_scale_specs = rhs_specs lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d265be398..3348c725b 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -581,9 +581,9 @@ def shardy_sharding_rule( result_types, ) - prefix = "NormFwdPrimitive_" + prefix = "NormFwd_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 + value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 ) x_axes = scale_rules.input_spec @@ -604,6 +604,7 @@ def shardy_sharding_rule( mu, rsigma, ), + **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 98b9b7e78..021af4c9d 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -495,9 +495,9 @@ def shardy_sharding_rule( ): del out_dtype, scale_dtype, is_outer, mesh, result_types - prefix = "BaseDBiasQuantizePrimitive_" + prefix = "DBiasQuantize_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), + value_types[0].shape, unique_var=prefix + "x", flatten_axis=flatten_axis, ) @@ -519,6 +519,7 @@ def shardy_sharding_rule( return SdyShardingRule( (x_axes, ("…1",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + **scale_rules.factor_sizes, ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index e81a614f0..b7828e931 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -17,7 +17,7 @@ import operator import numpy as np -from jax.experimental.custom_partitioning import BATCHING +from jax.experimental.custom_partitioning import BATCHING, CompoundFactor from jax.tree_util import register_pytree_node_class import jax.numpy as jnp @@ -152,12 +152,15 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: @abstractmethod def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. @@ -232,12 +235,15 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. @@ -245,7 +251,7 @@ def get_shardy_sharding_rules( The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -323,20 +329,23 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization Returns: The Shardy rules for the scaling mode """ del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -562,52 +571,55 @@ def get_grouped_scale_shape( return (n_block_x * n_block_y,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization Returns: The Shardy rules for the scaling mode """ - del flatten_axis - input_spec = [f"{unique_var}{i}" for i in range(input_rank)] - rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] - colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] - - # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. - # Unfortunately, because Shardy rules are applied to the inner primitive, the - # only way to preserve the relationship is to lower unpadded scales to the - # underlying custom call and pad them in C++. Until that's implemented, the - # Shardy rules for block scales have to be completely disconnected from the - # Shardy rules for the tensor they belong to. - - # # We have to use two different factors in the two CompoundFactors because of Shardy - # # verifier requirements, even though they are the same. - # rowwise_var = unique_var - # colwise_var = f"{unique_var}_" - # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") - - # # The rowwise and colwise scale tensors should be sharded the same way as the input. - # # However, we need to adjust the dimensions where the block scaling factor applies. - # rowwise = input_spec.copy() - # rowwise[-1] = rowwise_var - - # colwise = input_spec.copy() - # colwise[flatten_axis - 1] = colwise_var - - # # This implementation needs to be updated for different block dims. - # assert self._block_dims == (1, 32) + input_rank = len(input_shape) + input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] + flatten_axis = (flatten_axis + input_rank) % input_rank + + # This implementation needs to be updated for different block dims. + assert self._block_dims == (1, 32) + + # We have to use two different factors in the two CompoundFactors because of Shardy + # verifier requirements, even though they are the same. + blocksizes = {} + colwise_var = f"{unique_var}_None" + rowwise_var = f"{unique_var}_None" + if not input_shape[-1] == 32: + rowwise_var = input_spec[-1] + "_compound" + input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") + blocksizes["blocksize_x"] = 32 + if not input_shape[flatten_axis - 1] == 32: + colwise_var = input_spec[flatten_axis - 1] + "_compound" + input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") + blocksizes["blocksize_y"] = 32 + + # The rowwise and colwise scale tensors should be sharded the same way as the input. + # However, we need to adjust the dimensions where the block scaling factor applies. + rowwise = input_spec.copy() + rowwise[-1] = rowwise_var + + colwise = input_spec.copy() + colwise[flatten_axis - 1] = colwise_var return QuantizeShardyRules( tuple(input_spec), tuple(rowwise), tuple(colwise), - {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, + blocksizes, ) @@ -697,18 +709,22 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return self._get_impl().get_quantize_layout(usage) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis=-1 + self, + input_shape, + unique_var, + flatten_axis=-1, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 From afd15a16891fdc5d0f3efeb21e44ab15b54634c2 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 24 Sep 2025 15:52:54 -0400 Subject: [PATCH 007/141] [JAX] Update JAX version requirement in pyproject.toml (#2197) update jax requirements Signed-off-by: Phuong Nguyen --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ef112d279..64ff4c5ce 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,8 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax[cuda12]", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", +"torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" From 9e727966f4505d6740372572f89facd9d01f4c40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 26 Sep 2025 20:26:29 +0200 Subject: [PATCH 008/141] [PyTorch] Unpin version of onnxscript and onnxruntime (#2202) * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski --- build_tools/pytorch.py | 2 +- qa/L1_pytorch_onnx_unittest/test.sh | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 33a3abfb7..a974e370d 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,7 +14,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] + return ["torch>=2.1", "einops", "onnxscript", "onnx"] def test_requirements() -> List[str]: diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 720aa79e2..7fce13a3d 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -3,8 +3,8 @@ # See LICENSE for license information. -pip3 install onnxruntime==1.20.1 -pip3 install onnxruntime_extensions==0.13.0 +pip3 install onnxruntime +pip3 install onnxruntime_extensions : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} From 4d1457865847a83bb3b4582149188160fedddf98 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 26 Sep 2025 22:39:21 -0400 Subject: [PATCH 009/141] [JAX] Fix XML filename in the L0_jax_uniitest (#2205) fix xml file name Signed-off-by: Phuong Nguyen --- qa/L0_jax_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index e4a3f4630..cb097d492 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" From d75bf43f2e6fdc01afdf96a91b09245dc3c4987f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Sat, 27 Sep 2025 12:45:24 -0400 Subject: [PATCH 010/141] [JAX] CollectiveGemm (#2166) * init cgemm + unit tests * UB bootstrap with NCCL, no MPI dependency * add NVLINK-P2P check + error message * skip tests if no NVLINK available * use std::vector to store ncclComm_t * update misuse of TP warning Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- build_tools/jax.py | 1 + examples/jax/collective_gemm/common.py | 245 ++++++++++ examples/jax/collective_gemm/conftest.py | 29 ++ .../jax/collective_gemm/run_test_cgemm.sh | 111 +++++ .../jax/collective_gemm/test_dense_grad.py | 214 ++++++++ examples/jax/collective_gemm/test_gemm.py | 206 ++++++++ .../test_layernorm_mlp_grad.py | 272 +++++++++++ qa/L0_jax_distributed_unittest/test.sh | 4 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 98 +++- .../userbuffers/userbuffers-host.cpp | 30 +- .../transformer_engine/comm_gemm_overlap.h | 25 + transformer_engine/common/util/logging.h | 10 + transformer_engine/jax/cpp_extensions/gemm.py | 458 ++++++++++++++++-- transformer_engine/jax/cpp_extensions/misc.py | 8 + transformer_engine/jax/csrc/extensions.h | 9 +- .../jax/csrc/extensions/cgemm_helper.cpp | 259 ++++++++++ .../jax/csrc/extensions/cgemm_helper.h | 189 ++++++++ .../jax/csrc/extensions/gemm.cpp | 140 +++++- transformer_engine/jax/csrc/extensions/misc.h | 26 + .../jax/csrc/extensions/pybind.cpp | 12 +- transformer_engine/jax/dense.py | 73 ++- transformer_engine/jax/flax/transformer.py | 1 + transformer_engine/jax/layernorm_mlp.py | 43 +- transformer_engine/jax/sharding.py | 19 + 24 files changed, 2385 insertions(+), 97 deletions(-) create mode 100644 examples/jax/collective_gemm/common.py create mode 100644 examples/jax/collective_gemm/conftest.py create mode 100644 examples/jax/collective_gemm/run_test_cgemm.sh create mode 100644 examples/jax/collective_gemm/test_dense_grad.py create mode 100644 examples/jax/collective_gemm/test_gemm.py create mode 100644 examples/jax/collective_gemm/test_layernorm_mlp_grad.py create mode 100644 transformer_engine/jax/csrc/extensions/cgemm_helper.cpp create mode 100644 transformer_engine/jax/csrc/extensions/cgemm_helper.h diff --git a/build_tools/jax.py b/build_tools/jax.py index 67efbf00f..1f9552eb6 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -87,4 +87,5 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, + libraries=["nccl"], ) diff --git a/examples/jax/collective_gemm/common.py b/examples/jax/collective_gemm/common.py new file mode 100644 index 000000000..da79b2137 --- /dev/null +++ b/examples/jax/collective_gemm/common.py @@ -0,0 +1,245 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Shared functions for the comm_overlap tests""" + +import jax.numpy as jnp +import numpy as np + + +# Add this after your existing imports +def dtype_tols(dtype, rtol=None, atol=None): + """Expected numerical tolerance for a data type.""" + # Return immediately if tolerances are fully specified + if rtol is not None and atol is not None: + return {"rtol": rtol, "atol": atol} + + # Default tolerances for common dtypes + if dtype in [jnp.float32, "float32"]: + return {"rtol": 1e-5, "atol": 1e-8} + elif dtype in [jnp.float16, "float16"]: + return {"rtol": 1e-3, "atol": 1e-6} + elif dtype in [jnp.bfloat16, "bfloat16"]: + return {"rtol": 1e-2, "atol": 1e-5} + else: + return {"rtol": 1e-5, "atol": 1e-8} + + +def assert_allclose( + actual, + desired, + rtol=None, + atol=None, + dtype=None, + **kwargs, +): + """Check if two tensors are close.""" + # Infer data type if needed + if dtype is None: + if isinstance(actual, float): + dtype = "float32" + else: + dtype = actual.dtype + + # Determine tolerances + tols = {} + if rtol is None or atol is None: + tols = dtype_tols(dtype) + if rtol is not None: + tols["rtol"] = rtol + if atol is not None: + tols["atol"] = atol + + # Cast tensors to fp32 + if not isinstance(actual, float): + actual = actual.astype(jnp.float32) + if not isinstance(desired, float): + desired = desired.astype(jnp.float32) + + # Check if tensors are close + np.testing.assert_allclose(actual, desired, **tols, **kwargs) + + +def assert_allclose_print_index(ref_output, gathered_output, rtol=1e-5, atol=1e-8): + if not jnp.allclose(ref_output, gathered_output, rtol=rtol, atol=atol): + diff = jnp.abs(ref_output - gathered_output) + mask = diff > (atol + rtol * jnp.abs(gathered_output)) + print(mask.astype(int)) + print(jnp.where(mask, diff, 0)) + + +# Shared constants for all tests +DP_AXIS = "data" +TPSP_AXIS = "tensor_sequence" +PARAMS_KEY = "params" + +# Shared functions for distributed testing +import argparse +import jax +from jax.experimental import mesh_utils +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap + +# Global flag to track if distributed has been initialized +_distributed_initialized = False + + +def _is_distributed_initialized(): + """Check if JAX distributed has been initialized.""" + return _distributed_initialized + + +def _initialize_distributed(args): + """Initialize JAX distributed with custom arguments.""" + global _distributed_initialized + + # Check if already initialized + if _distributed_initialized: + return + + if args.coordinator_address is None or args.num_processes is None or args.process_id is None: + raise ValueError( + "All distributed initialization arguments are required: " + "--coordinator-address, --num-processes, --process-id" + ) + if args.local_device_ids is None: + assert ( + args.num_devices_per_process is not None + ), "Either local_device_ids or num_devices_per_process must be provided" + # Calculate device range for this process + # Single process single device: each process gets one unique device + # Single process multiple devices: each process gets a unique range of devices + start_device = args.process_id * args.num_devices_per_process + device_range = range(start_device, start_device + args.num_devices_per_process) + global_device_ids_for_this_process = ",".join(map(str, device_range)) + else: + # Use explicitly provided global device IDs + global_device_ids_for_this_process = args.local_device_ids + args.num_devices_per_process = len(args.local_device_ids.split(",")) + + assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" + + print( + f"Initializing JAX distributed with coordinator={args.coordinator_address}, " + f"num_processes={args.num_processes}, process_id={args.process_id}" + ) + # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process" + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=global_device_ids_for_this_process, + ) + + _distributed_initialized = True + jax.clear_caches() + jax.config.update( + "jax_use_shardy_partitioner", False + ) # CollectiveGEMM does not work with Shardy yet + + assert jax.local_device_count() == 1, ( + f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" + f" {jax.local_device_count()}" + ) + + devices_per_process = 1 + num_total_devices = args.num_processes + + print( + f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," + f" devices_per_process={devices_per_process}, process_id={args.process_id}" + ) + + collective_gemm_bootstrap( + num_total_devices=num_total_devices, + num_devices_per_process=devices_per_process, + process_id=args.process_id, + tensor_parallel_size=args.tensor_parallel_size, + ) + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +def _create_mesh(args): + """Create mesh configuration with proper validation.""" + num_gpu = args.num_processes * args.num_devices_per_process + assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices" + num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args) + + print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)") + + device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) + mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS)) + return mesh + + +def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): + """Create common argument parser for all collective GEMM tests.""" + parser = argparse.ArgumentParser(description=description) + + # Distributed initialization arguments + parser.add_argument( + "--coordinator-address", + type=str, + default=None, + help="Coordinator address for distributed initialization", + ) + parser.add_argument( + "--num-processes", + type=int, + default=None, + help="Number of processes for distributed initialization", + ) + parser.add_argument( + "--process-id", type=int, default=None, help="Process ID for distributed initialization" + ) + parser.add_argument( + "--local-device-ids", + type=str, + default=None, + help="Local device IDs for distributed initialization (comma-separated)", + ) + parser.add_argument( + "--num-devices-per-process", type=int, default=1, help="Number of devices per process" + ) + + # Test configuration arguments + parser.add_argument( + "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" + ) + parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") + parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") + parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") + parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") + parser.add_argument( + "--collective-type", + type=str, + default="all_gather", + choices=["all_gather", "reduce_scatter"], + help="Type of collective operation", + ) + parser.add_argument( + "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + ) + parser.add_argument( + "--enable-data-parallel", action="store_true", help="Enable data parallelism" + ) + parser.add_argument( + "--enable-result-check", action="store_true", default=True, help="Enable result checking" + ) + + return parser diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py new file mode 100644 index 000000000..83937971a --- /dev/null +++ b/examples/jax/collective_gemm/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""config for collective_gemm tests""" +import pytest + + +def pytest_addoption(parser): + """Pytest hook for collective_gemm tests""" + parser.addoption("--coordinator-address", action="store", default="localhost:12345") + parser.addoption("--num-processes", action="store", default=1) + parser.addoption("--process-id", action="store", default=0) + parser.addoption("--local-device-ids", action="store", default=None) + + +@pytest.fixture(autouse=True) +def distributed_args(request): + """Fixture for querying distributed initialization arguments""" + if request.cls: + request.cls.coordinator_address = request.config.getoption("--coordinator-address") + request.cls.num_processes = int(request.config.getoption("--num-processes")) + request.cls.process_id = int(request.config.getoption("--process-id")) + request.cls.local_device_ids = request.config.getoption("--local-device-ids") + request.cls.num_devices_per_process = ( + 1 + if request.cls.local_device_ids is None + else len(request.cls.local_device_ids.split(",")) + ) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh new file mode 100644 index 000000000..5bf7ccb59 --- /dev/null +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -0,0 +1,111 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +# Check if NVLINK is supported before running tests +echo "*** Checking NVLINK support***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? + +# Check if command failed OR output indicates no NVLINK +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform" + echo "Collective GEMM tests require NVLINK connectivity" + echo "SKIPPING all tests" + exit 0 +else + echo "NVLINK support detected" +fi + +# Define the test files to run +TEST_FILES=( +"test_gemm.py" +"test_dense_grad.py" +"test_layernorm_mlp_grad.py" +) + +echo +echo "*** Executing tests in examples/jax/collective_gemm/ ***" + +HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM + +# Run each test file across all GPUs +for TEST_FILE in "${TEST_FILES[@]}"; do + echo + echo "=== Starting test file: $TEST_FILE ..." + + # Clear PIDs array for this test file + PIDS=() + + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_FILE}_gpu_${i}.log" + + if [ $i -eq 0 ]; then + # For process 0: show live output AND save to log file using tee + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + # For other processes: redirect to log files only + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done + + # Wait for all processes to finish + wait + + # Check and print the log content from process 0 (now has log file thanks to tee) + if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE SKIPPED" + elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE FAILED" + HAS_FAILURE=1 + else + echo "... $TEST_FILE PASSED" + fi + + # Remove the log files after processing them + wait + rm ${TEST_FILE}_gpu_*.log +done + +wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + +exit $HAS_FAILURE diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py new file mode 100644 index 000000000..df2dd5618 --- /dev/null +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.dense import dense + +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOp, + CollectiveOpSet, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(collective_op): + if collective_op.is_all_gather: + input_axes = (DP_AXIS, TPSP_AXIS, None) + weight_axes = (None, TPSP_AXIS) + bias_axes = (TPSP_AXIS,) + output_axes = (DP_AXIS, None, TPSP_AXIS) + else: # RS + input_axes = (DP_AXIS, None, TPSP_AXIS) + weight_axes = (TPSP_AXIS, None) + bias_axes = (None,) + output_axes = (DP_AXIS, TPSP_AXIS, None) + return input_axes, weight_axes, bias_axes, output_axes + + +def _get_operand_sharding(mesh, collective_op): + input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes)) + weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes)) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes)) + return x_sharding, weight_sharding, bias_sharding + + +def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + output = dense( + x, + weight, + bias, + contracting_dims=((2,), (0,)), + input_axes=input_axes, + kernel_axes=weight_axes, + output_axes=output_axes, + collective_op_set=collective_op_set, + ) + return jnp.mean(output.astype(jnp.float32)) + + +def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + return jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + ) + + +def run_dense_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + + x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op) + ref_output, ref_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + noop_collective_op_set, + ) + output, sharded_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + collective_op_set, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveDenseGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather(self): + """Test Collective Dense Gradient with AllGather""" + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter(self): + """Test Collective Dense Gradient with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_dense_grad.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_dense_grad_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py new file mode 100644 index 000000000..307e4444e --- /dev/null +++ b/examples/jax/collective_gemm/test_gemm.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective GEMM test on multi-GPU with tensor parallelism + +This script uses custom distributed initialization with the following arguments: +- --coordinator-address: Coordinator address for distributed initialization +- --num-processes: Number of processes for distributed initialization +- --process-id: Process ID for distributed initialization +- --local-device-ids: Local device IDs for distributed initialization + +Example: + python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3 +""" +import unittest +import os +from functools import partial + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp +from transformer_engine.jax.sharding import MeshResource + + +def _get_operand_sharding(mesh, collective_op, is_with_dp): + + dp_axis = DP_AXIS if is_with_dp else None + if collective_op == CollectiveOp.ALL_GATHER: + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS)) + bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + else: # RS + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(None)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) +def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + collective_op=collective_op, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +def run_gemm_tests(args, mesh=None): + """Execute GEMM tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + print(f"Device mesh: {mesh}") + + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( + mesh, collective_op, args.enable_data_parallel + ) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + ref_output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=CollectiveOp.NONE, + output_sharding=output_sharding, + ) + output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=collective_op, + # CollectiveGEMM output should have a correct sharding without applying sharding constraint + output_sharding=None, + ) + assert ( + ref_output.sharding == output.sharding + ), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}" + gathered_ref_output = jax.lax.with_sharding_constraint( + ref_output, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_output = jax.lax.with_sharding_constraint( + output, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(gathered_ref_output) + jax.block_until_ready(gathered_output) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(gathered_ref_output, gathered_output) + + +class TestCollectiveGemmWithDP(unittest.TestCase): + """Collective GEMM with DP unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective GEMM test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather_with_dp(self): + """Test Collective GEMM with AllGather""" + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter_with_dp(self): + """Test Collective GEMM with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 5: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_gemm.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + sys.exit(1) + + args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args() + _initialize_distributed(args) + run_gemm_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py new file mode 100644 index 000000000..7bd6eb6a3 --- /dev/null +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.layernorm_mlp import layernorm_mlp + +from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOpSet, + CollectiveOp, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(): + input_1_axes = (DP_AXIS, TPSP_AXIS, None) + weight_1_axes = (None, None, TPSP_AXIS) + bias_axes_1 = (None, TPSP_AXIS) + input_2_axes = (DP_AXIS, None, TPSP_AXIS) + weight_2_axes = (TPSP_AXIS, None) + bias_axes_2 = (None,) + return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 + + +def _get_operand_sharding(mesh): + input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = ( + _get_logical_axes() + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes)) + weight_1_sharding = NamedSharding(mesh, PartitionSpec(*weight_1_axes)) + bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1)) + weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes)) + bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2)) + return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding + + +def _mean_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + output = layernorm_mlp( + x, + gamma, + beta=None, + kernels=[weight_1, weight_2], + biases=[bias_1, bias_2], + norm_type="rmsnorm", + dot_1_input_axes=input_1_axes, + dot_2_input_axes=input_2_axes, + kernel_1_axes=weight_1_axes, + kernel_2_axes=weight_2_axes, + activation_type=("gelu",), + collective_op_sets=collective_op_sets, + ) + return jnp.mean(output) + + +def _value_and_grad_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + return jax.jit( + jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) + )( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + + +def run_layernorm_mlp_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split( + rng, 7 + ) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight_1 = jax.random.normal( + weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_in) + bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16) + weight_2 = jax.random.normal( + weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_out) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16) + gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt( + args.hidden_in + ) + collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER) + collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER) + collective_op_sets = (collective_op_set_1, collective_op_set_2) + noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + + with mesh, fp8_autocast( + enabled=False, + fp8_recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = ( + _get_operand_sharding(mesh) + ) + x_sharded = jax.device_put(x, x_sharding) + weight_1_sharded = jax.device_put(weight_1, weight_1_sharding) + bias_1_sharded = jax.device_put(bias_1, bias_1_sharding) + weight_2_sharded = jax.device_put(weight_2, weight_2_sharding) + bias_2_sharded = jax.device_put(bias_2, bias_2_sharding) + + input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes() + ref_output, ref_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + noop_collective_op_sets, + ) + output, sharded_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveLayerNormMLPGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_layernorm_mlp_grad(self): + """Test Collective Dense Gradient with AllGather""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_layernorm_mlp_grad.py --coordinator-address
" + " --num-processes --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_layernorm_mlp_grad_tests(args, mesh=None) diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index d9c46347f..ae45f398e 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -29,6 +29,10 @@ wait python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" wait TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +wait + +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh" +wait if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec29e6e12..56369db27 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -64,6 +64,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl #endif _comm_created = true; } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) { + printf("!!! [UB] Register UBuf %d\n", _ub_reg); + } _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA( @@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! [UBP2P] UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); @@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { } } +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor @@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 1ce89c512..6c7bed55a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra); + // Check for NVLINK support before attempting IPC operations + if (comm->nvsize > 1) { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + cudaDeviceProp deviceProp; + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device)); + bool peer_access_available = false; + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + int can_access_peer; + cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i); + if (peer_result == cudaSuccess && can_access_peer) { + peer_access_available = true; + break; + } + } + } + if (!peer_access_available) { + free(tmp); + NVTE_ERROR( + "No peer-to-peer access available between GPUs. This platform does not support the " + "GPU-to-GPU " + "communication required for multi-GPU userbuffers. Consider using single-GPU mode."); + return 1; + } + } + for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], cudaIpcMemLazyEnablePeerAccess)); } } @@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; return comm->free_region++; + printf("***** Returning *****\n"); } diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 4d65e26ce..cffc411a0 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -67,6 +67,11 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python @@ -78,17 +83,26 @@ class CommOverlapCore { virtual ~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -148,6 +162,10 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python @@ -224,6 +242,10 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, _stop_recv; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python @@ -237,6 +259,9 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 941899b28..c2ce684c4 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -12,6 +12,8 @@ #include #include +#include "nccl.h" + #ifdef NVTE_WITH_CUBLASMP #include #endif // NVTE_WITH_CUBLASMP @@ -104,4 +106,12 @@ #endif // NVTE_WITH_CUBLASMP +#define NVTE_CHECK_NCCL(expr) \ + do { \ + const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ + if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ + NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ + } \ + } while (false) + #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 118000be7..e5fcdac3c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -6,8 +6,10 @@ import math import operator from collections.abc import Iterable -from typing import Tuple, Sequence, Union +from dataclasses import dataclass from functools import partial, reduce +from typing import Tuple, Sequence, Union +from enum import Enum import warnings import jax @@ -16,8 +18,13 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule -import transformer_engine_jax as tex -from transformer_engine_jax import get_num_compute_streams +from transformer_engine_jax import ( + get_num_compute_streams, + JAXX_Collective_Op, + get_device_compute_capability, + initialize_cgemm_communicator, + get_cgemm_num_max_streams, +) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -37,11 +44,19 @@ is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, ) -from ..sharding import global_mesh_resource -from .misc import get_padded_spec +from .misc import get_padded_spec, is_all_reduce_in_float32 +from ..sharding import ( + global_mesh_resource, + tpsp_axis_size, + dp_or_fsdp_axis_size, +) __all__ = [ + "CollectiveOp", + "CollectiveOpSet", + "collective_gemm_bootstrap", + "noop_collective_op_set", "gemm", "grouped_gemm", "gemm_uses_jax_dot", @@ -56,7 +71,7 @@ def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) >= 90: + if get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 @@ -152,6 +167,161 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +def collective_gemm_bootstrap( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams=3, + compute_stream_priority=0, + communication_stream_priority=0, + num_sm_for_communication=2, + use_ce=True, + aggregate_all_gather=False, +): + """Initialize NCCL communicators for Collective GEMM operations. + + This function sets up the distributed communication infrastructure needed for + tensor parallel collective GEMM operations. It supports two main scenarios: + + 1. **Multi-device per process**: TP domain = single process + - Each process manages multiple GPUs (num_devices_per_process > 1) + - TP group consists of GPUs within the same process + - Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4 + + 2. **Single device per process**: TP domain spans multiple processes + - Each process manages one GPU (num_devices_per_process = 1) + - TP group spans across multiple processes + - Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4 + + Args: + num_total_devices (int): Total number of ranks across all processes. + Must be divisible by num_devices_per_process. + num_devices_per_process (int): Number of GPUs per process. + - For multi-device: equals tp_size (e.g., 4 GPUs per process) + - For single-device: equals 1 (1 GPU per process) + process_id (int): Process identifier (0-based). + Must be in range [0, num_total_devices // num_devices_per_process). + tensor_parallel_size (int): Size of tensor parallel groups. + Must divide num_total_devices evenly. + num_max_streams (int, optional): Maximum number of CUDA streams for overlap. + Higher values enable more parallelism but use more GPU resources. Default: 3. + compute_stream_priority (int, optional): Priority for GEMM computation streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + communication_stream_priority (int, optional): Priority for NCCL communication streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + num_sm_for_communication (int, optional): Number of streaming multiprocessors + reserved for communication operations. Default: 2. + use_ce (bool, optional): Enable CUDA copy engines for memory transfers. + Can improve performance by offloading memory operations. Default: True. + aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations + into larger ones for better efficiency. Default: False. + + Raises: + AssertionError: If num_total_devices is not divisible by num_devices_per_process, + or if process_id is out of valid range. + AssertionError: If num_devices_per_process is not 1 (Temporary: only single device per process is supported for now) + RuntimeError: If NCCL initialization fails or if configuration + is invalid (e.g., insufficient GPUs). + + Example: + # Basic initialization (single device per process) + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4 + ) + + # Advanced configuration with custom performance settings + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4, + num_max_streams=5, # More parallelism + compute_stream_priority=1, # Lower compute priority + communication_stream_priority=0, # Higher comm priority + num_sm_for_communication=4, # More SMs for communication + use_ce=True, # Enable copy engines + aggregate_all_gather=True # Aggregate small operations + ) + + Note: + This function must be called after JAX distributed initialization + and before any collective GEMM operations. Each process should call + this function with its own unique process_id. + """ + + assert ( + num_devices_per_process == 1 and jax.local_device_count() == 1 + ), "Only single device per process is supported at the moment!" + assert num_total_devices % num_devices_per_process == 0, ( + f"Invalid num_total_devices={num_total_devices}," + f" num_devices_per_process={num_devices_per_process}" + ) + assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + initialize_cgemm_communicator( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams, + compute_stream_priority, + communication_stream_priority, + num_sm_for_communication, + use_ce, + aggregate_all_gather, + ) + + +class CollectiveOp(Enum): + "Enum for Collective Type in Collective GEMM" + + NONE = JAXX_Collective_Op.NONE + ALL_GATHER = JAXX_Collective_Op.ALL_GATHER + REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER + + @property + def is_all_gather(self) -> bool: + """Check if AllGather""" + return self == CollectiveOp.ALL_GATHER + + @property + def is_reduce_scatter(self) -> bool: + """Check if ReduceScatter""" + return self == CollectiveOp.REDUCE_SCATTER + + @property + def is_none(self) -> bool: + """Check if None""" + return self == CollectiveOp.NONE + + +@dataclass(frozen=True) +class CollectiveOpSet: + """ + A set of CollectiveOp objects that provide complementary collective GEMM configurations for the Forward and Backward passes through Dense-layers. + """ + + forward: CollectiveOp + backward: CollectiveOp + + @staticmethod + def create(forward_collective_op: CollectiveOp): + """Create a set of CollectiveOp for forward and backward passes""" + if forward_collective_op.is_all_gather: + backward_collective_op = CollectiveOp.REDUCE_SCATTER + elif forward_collective_op.is_reduce_scatter: + backward_collective_op = CollectiveOp.ALL_GATHER + else: + backward_collective_op = CollectiveOp.NONE + return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op) + + +noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE) + + @partial(jax.jit, static_argnums=(1, 2)) def swizzled_scale(scale_inv, flatten_axis, is_colwise): "Swizzle scale_inv via JAX transpose ops" @@ -174,7 +344,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12) + impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 inner_primitive = None outer_primitive = None @@ -193,8 +363,12 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del use_split_accumulator + del use_split_accumulator, transpose_batch_sequence def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -238,7 +412,7 @@ def _dims_are_consecutive(dims): ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING - and not tex.is_non_nt_fp8_gemm_supported() + and not is_fp8_gemm_with_all_layouts_supported() ): assert not lhs_is_transposed and rhs_is_transposed, ( "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " @@ -263,6 +437,19 @@ def _dims_are_consecutive(dims): out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + # Adjust output shape for comm+GEMM overlap + if not collective_op.is_none and not is_outer: # Inner abstract + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + overlap_out_shape = list(out_shape).copy() + if collective_op.is_all_gather: + overlap_out_shape[1] *= tpsp_axis_size() + else: # RS + overlap_out_shape[sequence_dim] = ( + overlap_out_shape[sequence_dim] // tpsp_axis_size() + ) + assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}" + output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) + # Validate bias bias_shape = (0,) bias_dtype = out_dtype @@ -302,9 +489,12 @@ def _dims_are_consecutive(dims): pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) # Declare cuBLAS workspace + workspace_size = get_cublas_workspace_size_bytes() + if not collective_op.is_none: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. - workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace_size += 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return output, bias_grad, pre_gelu_out, workspace @@ -330,8 +520,12 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del out_dtype + del out_dtype, transpose_batch_sequence, sequence_dim, is_outer lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -350,6 +544,7 @@ def lowering( "fuse_gelu": fuse_gelu, "grad": grad, "use_split_accumulator": use_split_accumulator, + "collective_op": int(collective_op.value), } operand_output_aliases = {} @@ -378,6 +573,10 @@ def impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): if scaling_mode.is_1d_block_scaling(): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -396,7 +595,34 @@ def impl( lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) - outputs = GemmPrimitive.inner_primitive.bind( + # Alter lhs blocks so that CGEMM RS outputs correctly + if ( + collective_op.is_reduce_scatter + and not transpose_batch_sequence + and not is_outer + and not lhs.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = lhs.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = lhs.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + lhs = reordered.reshape(original_shape) + + (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, rhs, @@ -410,8 +636,39 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ) - return outputs[:-1] # discard workspace array + # Alter output blocks for CGEMM AG + if ( + collective_op.is_all_gather + and not transpose_batch_sequence + and not is_outer + and not output.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = output.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = output.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + output = reordered.reshape(original_shape) + + return [output, bias_grad, pre_gelu_out] @staticmethod def outer_impl( @@ -428,6 +685,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): return GemmPrimitive.impl( lhs, @@ -443,6 +704,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ) @staticmethod @@ -456,7 +721,12 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + collective_op, + transpose_batch_sequence, + sequence_dim, + is_outer, ): + del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims @@ -484,6 +754,10 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -492,6 +766,8 @@ def batcher( def _parse_operand_output_specs( arg_infos, contracting_dims, + transpose_batch_sequence, + collective_op, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -499,14 +775,12 @@ def _parse_operand_output_specs( # Ensure that tensor sequence parallelism is not used via setting tp_resource if gsr.tp_resource is not None: - for i in range(len(lhs_specs) - 1): - if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource: - warnings.warn( - "Tensor sequence parallelism is detected as" - f" tp_resource='{gsr.tp_resource}' appears twice consecutively in" - f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for" - " tensor sequence parallelism to avoid potential issues." - ) + if gsr.tp_resource in lhs_specs: + warnings.warn( + "Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'" + " appears in lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource" + " for tensor sequence parallelism to avoid potential issues." + ) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) @@ -528,10 +802,43 @@ def _parse_operand_output_specs( assert reduce_spec is None, "Multiple reduce dimension is detected!" reduce_spec = l + sequence_dim = None + + # Find sequence dimension in lhs_specs if tensor sequence parallel is enabled + # We only do CollectiveGemm AG on the x or dY thus they always the LHS and have sequence dim + if collective_op.is_all_gather: + try: + tpsp_idx = lhs_specs.index(gsr.tpsp_resource) + except ValueError as exc: + raise ValueError( + f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}." + " Please check your sharding configuration." + ) from exc + sequence_dim = tpsp_idx + assert (sequence_dim == 1) ^ transpose_batch_sequence, ( + "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)" + " or (sequence_dim=0 and transpose_batch_sequence=True). Received:" + f" sequence_dim={sequence_dim}," + f" transpose_batch_sequence={transpose_batch_sequence}." + ) + + elif collective_op.is_reduce_scatter: + assert reduce_spec == gsr.tpsp_resource, ( + "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" + f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" + ) + sequence_dim = int(not transpose_batch_sequence) + if reduce_spec is not None: # Other non-reduce cdims (if exists) need to be unsharded lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) - rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) + # Only do AG Sequence dim if not Overlap + if collective_op.is_all_gather: + rhs_cspecs = tuple( + s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs + ) + else: + rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. @@ -551,13 +858,31 @@ def _parse_operand_output_specs( for spec in rhs_non_cspecs ) - # Non-contracting dims of LHS to be gathered along the SP axis. - # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for - # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. - lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) + # Only do AG Sequence dim if not Overlap + if not collective_op.is_all_gather: + # Non-contracting dims of LHS to be gathered along the SP axis. + # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for + # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. + lhs_non_cspecs = tuple( + None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs + ) out_specs = lhs_non_cspecs + rhs_non_cspecs + # Only do AG Sequence dim if not Overlap RS + if collective_op.is_all_gather: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :] + elif collective_op.is_reduce_scatter: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = ( + out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :] + ) + # specs = merge(cspecs, non_cspecs) lhs_specs, rhs_specs = map( lambda cdims, cspecs, non_cspecs: ( @@ -572,10 +897,14 @@ def _parse_operand_output_specs( bias_specs = tuple(list(rhs_non_cspecs).copy()) gelu_specs = tuple(list(out_specs).copy()) + if not collective_op.is_none: + assert sequence_dim >= 0, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), reduce_spec, + sequence_dim, ) @staticmethod @@ -587,6 +916,10 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, @@ -595,11 +928,16 @@ def infer_sharding_from_operands( out_dtype, scaling_mode, grad, + use_split_accumulator, + result_infos, + is_outer, + sequence_dim, ) - del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs( + arg_infos, contracting_dims, transpose_batch_sequence, collective_op + ) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -624,20 +962,29 @@ def partition( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, ): - del result_infos + del result_infos, is_outer, sequence_dim ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), reduce_spec, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + inferred_sequence_dim, + ) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + ) - # Assemble argument shardings - # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. + # Block scale inverses match their operands, but tensor scale inverses are unsharded. none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) @@ -686,11 +1033,19 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=inferred_sequence_dim, + is_outer=False, + collective_op=collective_op, ) - # All-Reduce GEMM output - if reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + if reduce_spec is not None and not collective_op.is_reduce_scatter: + if is_all_reduce_in_float32(): # For unittest only + outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype( + out_dtype + ) + else: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) return outputs @@ -705,12 +1060,22 @@ def shardy_sharding_rule( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, operand_types, result_types, ): del out_dtype, grad, use_split_accumulator - del mesh, result_types + del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer + + if not collective_op.is_none: + raise NotImplementedError( + "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off" + " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false" + ) prefix = "Gemm_" @@ -792,6 +1157,8 @@ def _te_gemm( fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -800,6 +1167,7 @@ def _te_gemm( lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -859,6 +1227,10 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=-1, + is_outer=True, + collective_op=collective_op, ) @@ -1176,6 +1548,8 @@ def gemm( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, **kwargs, ) -> Tuple[jnp.ndarray, ...]: r"""General matrix multiplication with optional quantization. @@ -1209,8 +1583,11 @@ def gemm( TE's custom call to cuBLAS GEMM. use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only - supported with TE's custom call to cuBLAS GEMM. + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + transpose_batch_sequence: bool, default = False + Transpose the batch and sequence dimensions of the input tensor. + collective_op: CollectiveOp, default = CollectiveOp.NONE + Collective operation type for collective GEMM. Returns ------- @@ -1254,6 +1631,7 @@ def gemm( "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) + assert collective_op.is_none, "JAX GEMM does not support collective GEMM" return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) outputs = _te_gemm( @@ -1262,6 +1640,8 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op, **kwargs, ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 3bda37128..52f5edbf3 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -293,3 +293,11 @@ def duplicate_with_new_description(self, desc: str): Create a new NamedSharding with the same mesh and spec but with a new description. """ return NamedSharding(self.mesh, self.spec, desc=desc) + + +@functools.lru_cache(maxsize=1) +def is_all_reduce_in_float32(): + """ + Check if all-reduce is in float32 + """ + return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1" diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 59079fe3f..92937dd46 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -13,6 +13,7 @@ #include #include #include +#include #include #include @@ -32,9 +33,6 @@ #include "transformer_engine/activation.h" #include "transformer_engine/multi_stream.h" -// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace -XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); - namespace transformer_engine { namespace jax { @@ -121,6 +119,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); @@ -134,4 +133,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); + #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp new file mode 100644 index 000000000..7082bfb03 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -0,0 +1,259 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "cgemm_helper.h" + +#include "common/util/system.h" +#include "nccl.h" + +namespace transformer_engine { +namespace jax { + +ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) { + ncclUniqueId unique_id; + + int tp_domain_id = get_tp_domain_id(); + bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0); + + pid_t pgid = getpgid(0); + + std::string base_path = getenv("NVTE_JAX_NCCL_FILE_PATH", "/tmp"); + std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) + + "_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) + + "_domain_" + std::to_string(tp_domain_id) + ".bin"; + + if (is_tp_leader) { + NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id)); + + // Write the ID to a temporary file + std::ofstream file(id_file, std::ios::binary); + NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file); + file.write(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + file.close(); + } else { + // Wait for the ID file to be created and read it + int attempts = 0; + const int max_attempts = 100; + while (attempts < max_attempts) { + std::ifstream file(id_file, std::ios::binary); + if (file.is_open()) { + file.read(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + if (file.gcount() == sizeof(ncclUniqueId)) { + file.close(); + break; + } + file.close(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + attempts++; + } + NVTE_CHECK(attempts < max_attempts, + "Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file); + } + + if (is_tp_leader) { + _nccl_id_file_name.push_back(id_file); + } + + return unique_id; +} + +void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size) { + // Validate inputs + NVTE_CHECK(num_devices_per_process == 1, + "num_devices_per_process must be == 1, got num_devices_per_process=", + num_devices_per_process); + NVTE_CHECK(num_total_devices >= 1, + "num_total_devices must be >= 1, got num_total_devices=", num_total_devices); + NVTE_CHECK( + num_total_devices % num_devices_per_process == 0, + "num_total_devices must be divisible by num_devices_per_process, got num_total_devices=", + num_total_devices, ", num_devices_per_process=", num_devices_per_process); + + // Validate TP size + NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size); + NVTE_CHECK(num_total_devices % tp_size == 0, + "num_total_devices must be divisible by tp_size, got num_total_devices=", + num_total_devices, ", tp_size=", tp_size); + + auto &handler = get(false); + handler.num_total_devices = num_total_devices; + handler.num_devices_per_process = num_devices_per_process; + handler.process_id = process_id; + handler.num_processes = num_total_devices / num_devices_per_process; + handler.tp_size = tp_size; + handler.tp_num_domains = num_total_devices / tp_size; + + // Initialize vectors with the correct size + handler.local_device_ids_within_process.resize(num_devices_per_process); + handler.local_device_ids_within_tp_domain.resize(num_devices_per_process); + handler.tp_domain_ids.resize(num_devices_per_process); + handler.global_device_ids.resize(num_devices_per_process); + handler.tp_comms.resize(num_devices_per_process); + + NVTE_CHECK(0 <= process_id && process_id < handler.num_processes, + "Invalid process_id=", process_id, ", which is out of range [0, ", + handler.num_processes, ")"); + + // Initialize local devices and calculate their global device IDs and TP topology + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + // Use the device that JAX has already assigned to this process + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + handler.local_device_ids_within_process[local_idx] = current_device; + handler.global_device_ids[local_idx] = process_id * num_devices_per_process + local_idx; + + // Calculate TP-related values for this device + int global_device_id = handler.global_device_ids[local_idx]; + if (num_devices_per_process == tp_size) { + // Scenario 1: Multi-device per process - TP domain = single process + handler.local_device_ids_within_tp_domain[local_idx] = local_idx; + handler.tp_domain_ids[local_idx] = process_id; + } else { + // Scenario 2: Single device per process - TP domain spans multiple processes + handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size; + handler.tp_domain_ids[local_idx] = global_device_id / tp_size; + } + } + + ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp"); + + NVTE_CHECK_NCCL(ncclGroupStart()); + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx])); + int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx]; + NVTE_CHECK_NCCL( + ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank)); + } + NVTE_CHECK_NCCL(ncclGroupEnd()); + + // Allocate device memory for barrier operations + NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int))); + + handler._initialize = true; + + // Bootstrap UB via creating a dummy CommOverlapP2PBase object + std::vector buffer_shape{1, 1}; + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, + JAXX_Collective_Op::ALL_GATHER); +} + +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag) { + auto &config = CgemmConfig::get(false); + config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); + auto &handler = CommunicatorHandler::get(false); + handler.init(num_total_devices, num_devices_per_process, process_id, tp_size); +} + +int GetCgemmNumMaxStreams() { + auto &config = CgemmConfig::get(); + return config.num_max_streams; +} + +CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector buffer_shape, + DType dtype, + JAXX_Collective_Op collective_op) { + auto &comm_handler = CommunicatorHandler::get(); + auto &cgemm_config = CgemmConfig::get(); + + int device_idx = comm_handler.get_local_device_idx_for_current_device(); + int64_t plan_id = 0; + hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast(dtype), + static_cast(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx); + + auto it = plan_map.find(plan_id); + if (it != plan_map.end()) { + return it->second.get(); + } + + if (comm_handler.num_devices_per_process == comm_handler.tp_size) { + // Multi-device per process + } else if (comm_handler.num_devices_per_process == 1) { + // Single device per process + NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0, + "For single device per process, num_total_devices must be divisible by tp_size, " + "got num_total_devices=", + comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size); + } else { + NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=", + comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size, + ". Supported scenarios: " + "(1) num_devices_per_process == tp_size (multi-device per process), " + "(2) num_devices_per_process == 1 (single device per process)"); + } + + std::unique_ptr executor; + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), + cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, + cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, + cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + + CommOverlapCore *executor_ptr = executor.get(); + plan_map[plan_id] = std::move(executor); + return executor_ptr; +} + +void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + NVTE_CHECK_NCCL( + ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr)); + cudaDeviceSynchronize(); +} + +void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes, + void *input_buf, size_t input_bytes, ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + size_t expected_output_bytes = input_bytes * tp_size; + NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ", + expected_output_bytes, ", got ", output_bytes); + + NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr)); + cudaDeviceSynchronize(); +} + +CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) { + allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm comm) { + this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm); + }; + barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); }; +} + +CommunicatorHandler::~CommunicatorHandler() { + if (_initialize && !tp_comms.empty()) { + for (auto &comm : tp_comms) { + if (comm != nullptr) { + ncclCommDestroy(comm); + } + } + } + if (_device_barrier) cudaFree(_device_barrier); + + for (const auto &file_path : _nccl_id_file_name) { + std::remove(file_path.c_str()); + } +} + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h new file mode 100644 index 000000000..84b2b8154 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -0,0 +1,189 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ +#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "transformer_engine/comm_gemm_overlap.h" + +namespace transformer_engine { +namespace jax { + +// Configuration singleton for CGEMM parameters +class CgemmConfig { + public: + int num_max_streams; + int gemm_priority; + int comm_priority; + int num_comm_sm; + bool use_ce; + bool aggregate_ag; + + static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm, + bool _use_ce, bool _aggregate_ag) { + auto &config = get(false); + config._initialized = true; + config.num_max_streams = _num_max_streams; + config.gemm_priority = _gemm_priority; + config.comm_priority = _comm_priority; + config.num_comm_sm = _num_comm_sm; + config.use_ce = _use_ce; + config.aggregate_ag = _aggregate_ag; + } + + static CgemmConfig &get(bool is_initialized = true) { + static thread_local CgemmConfig instance; + NVTE_CHECK( + instance._initialized == is_initialized, + "CgemmConfig must be initialized before using it, got is_initialized=", is_initialized); + return instance; + } + + CgemmConfig(const CgemmConfig &) = delete; + CgemmConfig &operator=(const CgemmConfig &) = delete; + + private: + CgemmConfig() = default; + ~CgemmConfig() = default; + bool _initialized = false; +}; + +// Forward declaration +class CollectiveGemmPlanRegistry; + +// NCCL communicator handler for collective GEMM operations +// Support both single process single device AND single process multi device +// Two scenarios: +// 1. Single process multiple devices: TP domain = process (num_devices_per_process == tp_size) +// 2. Single process single device: TP domain spans processes (num_devices_per_process == 1) +class CommunicatorHandler { + public: + int num_total_devices = -1; + int num_devices_per_process = -1; + int process_id = -1; + int num_processes = -1; + + int tp_size = -1; + int tp_num_domains = -1; + std::vector local_device_ids_within_tp_domain; + std::vector tp_domain_ids; + std::vector tp_comms; + + std::vector local_device_ids_within_process; + std::vector global_device_ids; + + int get_global_rank() const { + int device_idx = get_local_device_idx_for_current_device(); + return global_device_ids[device_idx]; + } + + void nccl_device_barrier_impl(ExtComm); + void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm); + + ncclComm_t get_comm_for_current_device() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_comms[device_idx]; + } + + int get_local_device_idx_for_current_device() const { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + for (int i = 0; i < num_devices_per_process; i++) { + if (local_device_ids_within_process[i] == current_device) { + return i; + } + } + NVTE_ERROR("Current CUDA device ", current_device, + " not found in local_device_ids_within_process"); + } + + int get_local_device_id_within_tp_domain() const { + int device_idx = get_local_device_idx_for_current_device(); + return local_device_ids_within_tp_domain[device_idx]; + } + + int get_tp_domain_id() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_domain_ids[device_idx]; + } + + int get_tp_num_domains() const { return tp_num_domains; } + + static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size); + + private: + ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); + + public: + static CommunicatorHandler &get(bool is_initialized = true) { + static CommunicatorHandler instance; + NVTE_CHECK(instance._initialize == is_initialized, + "CommunicatorHandler._initialize=", instance._initialize, + ", is_initialized=", is_initialized); + return instance; + } + + ExtAllgatherOp allgather_func; + ExtBarrierOp barrier_func; + + CommunicatorHandler(const CommunicatorHandler &) = delete; + CommunicatorHandler &operator=(const CommunicatorHandler &) = delete; + + private: + CommunicatorHandler(); + ~CommunicatorHandler(); + + bool _initialize = false; + int *_device_barrier = nullptr; + std::vector _nccl_id_file_name; +}; + +// Plan registry for caching collective GEMM executors +class CollectiveGemmPlanRegistry { + public: + static CollectiveGemmPlanRegistry &getInstance() { + static thread_local CollectiveGemmPlanRegistry instance; + return instance; + } + + CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, + JAXX_Collective_Op collective_op); + + private: + CollectiveGemmPlanRegistry() {} + CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete; + CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete; + + std::unordered_map> plan_map; +}; + +// Function declarations +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag); + +int GetCgemmNumMaxStreams(); + +} // namespace jax +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 06dded1d8..1467fa887 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -6,13 +6,19 @@ #include "transformer_engine/gemm.h" #include +#include +#include #include #include #include "../extensions.h" +#include "cgemm_helper.h" +#include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/string.h" #include "common/util/system.h" +#include "cuda_runtime.h" +#include "nccl.h" #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -66,12 +72,75 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + nvte_cublas_handle_init(); + + // Init UB buffer + if (collective_op != JAXX_Collective_Op::NONE) { + auto &comm_handler = CommunicatorHandler::get(); + std::vector lhs_shape = { + product(lhs.dimensions(), 0, lhs_axis_boundary), + product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())}; + std::vector rhs_shape = { + product(rhs.dimensions(), 0, rhs_axis_boundary), + product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())}; + + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; + + std::vector buffer_shape{0, 0}; + DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + } + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype, + collective_op); + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, + FFI::Bind() + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator") + .Attr("collective_op")); + Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, + JAXX_Collective_Op collective_op) { // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || @@ -83,16 +152,9 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); - // Output tensor std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, " - "expected ", - out_.numel(), " elements ", to_string_like(out_shape), " but got ", - output->element_count(), " elements ", to_string_like(output->dimensions())); // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; @@ -133,9 +195,62 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + + if (collective_op == JAXX_Collective_Op::NONE) { + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", + to_string_like(out_shape), " but got ", output->element_count(), " elements ", + to_string_like(output->dimensions())); + + nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), + rhs_transposed, lhs_transposed, grad, workspace_.data(), false, + use_split_accumulator, num_math_sm, stream); + } else { + std::vector buffer_shape{0, 0}; + DType buffer_dtype = out_dtype; + auto &comm_handler = CommunicatorHandler::get(); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + out_shape[0] = out_shape[0] * comm_handler.tp_size; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + out_shape[0] = out_shape[0] / comm_handler.tp_size; + } + auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, buffer_dtype, collective_op); + if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); + // Prepare the auxiliary buffer for the reduce-scattered GEMM output + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + + // Launch GEMM+RS + executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_, + pre_gelu_, workspace_, grad, false, use_split_accumulator, out_, + stream); + + } else if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + auto aux_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Empty + + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + // Copy the distributed LHS operand into the local chunk of the communication buffer + executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); + // Launch AG+GEMM + executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, stream); + } + } return ffi_with_cuda_error_check(); } @@ -161,7 +276,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") - .Attr("use_split_accumulator"), + .Attr("use_split_accumulator") + .Attr("collective_op"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index af7f54feb..c8fb713d7 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -87,5 +87,31 @@ constexpr struct Alignment { std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +template +void hash_combine(int64_t &seed, const T &v, Rest... rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + +enum class JAXX_Collective_Op : int64_t { + NONE = 0, + ALL_GATHER = 1, + REDUCE_SCATTER = 2, +}; + +static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { + switch (op) { + case JAXX_Collective_Op::ALL_GATHER: + return CommOverlapType::AG; + break; + case JAXX_Collective_Op::REDUCE_SCATTER: + return CommOverlapType::RS; + break; + default: + NVTE_ERROR("Invalid Collective Op ", static_cast(op)); + break; + } +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index afbeb644c..06e2e2e00 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -5,6 +5,8 @@ ************************************************************************/ #include "../extensions.h" +#include "cgemm_helper.h" +#include "common/util/cuda_runtime.h" namespace transformer_engine { namespace jax { @@ -57,7 +59,7 @@ pybind11::dict Registrations() { // GEMM dict["te_gemm_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler), pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM @@ -84,6 +86,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); + m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); + m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) @@ -159,6 +163,12 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) .export_values(); + + pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) + .value("NONE", JAXX_Collective_Op::NONE) + .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER) + .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER) + .export_values(); } } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index dd7f5e0e8..23df1a0ce 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -11,6 +11,7 @@ from typing import Tuple, Sequence from functools import partial +import warnings import jax import jax.numpy as jnp @@ -62,10 +63,13 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + batch_sequence_transpose: bool = False, input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, - quantizer_set: QuantizerSet = noop_quantizer_set, + output_axes: Tuple[str, ...] = None, using_global_amax_of_x: bool = False, + collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, + quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -78,12 +82,20 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract - quantizer_set: QuantizerSet which contains quantizers for different tensor types + batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + input_axes: Logical axes for sharding the activation input + kernel_axes: Logical axes for sharding the weight matrix + output_axes: Logical axes for sharding the output using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. + collective_op_set: A set of CollectiveOp objects for forward and backward passes. + quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ + if batch_sequence_transpose: + warnings.warn("batch_sequence_transpose is not well tested, use with caution!") + if not get_quantize_config().is_fp8_enabled(): input_dtype = x.dtype kernel = kernel.astype(input_dtype) @@ -93,32 +105,30 @@ def dense( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, ) return output -@partial( - jax.custom_vjp, - nondiff_argnums=( - 3, - 4, - 5, - 7, - ), -) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) def _dense( x, kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, # need to be a diff_arg for DelayedScaling state management ): """Internal implementation of dense layer transformation with custom VJP. @@ -130,10 +140,13 @@ def _dense( kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification + batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input + output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix - quantizer_set: QuantizerSet which contains quantizers for different tensor types using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. + collective_op_set: A set of CollectiveOp objects for forward and backward passes. + quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor @@ -143,10 +156,13 @@ def _dense( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, ) return output @@ -156,10 +172,13 @@ def _dense_fwd_rule( kernel, bias, contracting_dims, + batch_sequence_transpose, input_axes, kernel_axes, - quantizer_set, + output_axes, using_global_amax_of_x, + collective_op_set, + quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -202,9 +221,12 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set.forward, ) + output = with_sharding_constraint_by_logical_axes(output, output_axes) if use_bias and tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape @@ -223,8 +245,16 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, using_global_amax_of_x, ctx, grad -): # pylint: disable=unused-argument + contracting_dims, + batch_sequence_transpose, + input_axes, + kernel_axes, + output_axes, + using_global_amax_of_x, + collective_op_set, + ctx, + grad, +): """Backward pass rule for dense layer transformation. Returns: @@ -239,6 +269,7 @@ def _dense_bwd_rule( quantizer_set, flatten_axis_k, ) = ctx + grad = with_sharding_constraint_by_logical_axes(grad, output_axes) fwd_x_contracting_dims, fwd_k_contracting_dims = map( tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims @@ -266,8 +297,9 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set.backward, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -279,7 +311,10 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), + transpose_batch_sequence=batch_sequence_transpose, ) + + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9a..ad66684f2 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[ return drop_path_shape +# TODO(Phuong): move this function to sharding.py def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: """ Extend the given Flax logical axis rules with the predefined TransformerLayer's diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index e3eaa53e1..cf77f8e0a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -41,6 +41,7 @@ def layernorm_mlp( norm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6, + batch_sequence_transpose: bool = False, norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, @@ -49,6 +50,10 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + collective_op_sets: Tuple[tex.CollectiveOpSet] = ( + tex.noop_collective_op_set, + tex.noop_collective_op_set, + ), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -72,6 +77,7 @@ def layernorm_mlp( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + batch_sequence_transpose: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication @@ -80,6 +86,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -122,6 +129,7 @@ def layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -130,12 +138,13 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -147,6 +156,7 @@ def _layernorm_mlp( norm_type: str, zero_centered_gamma: bool, epsilon: float, + batch_sequence_transpose: bool, norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], @@ -155,6 +165,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -174,12 +185,16 @@ def _layernorm_mlp( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + batch_sequence_transpose: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for layernorm sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding + kernel_1_axes: Logical axes for first weight matrix sharding + kernel_2_axes: Logical axes for second weight matrix sharding ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of quantizer sets Returns: @@ -196,6 +211,7 @@ def _layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -204,6 +220,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ) return output @@ -220,6 +237,7 @@ def _layernorm_mlp_fwd_rule( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -228,6 +246,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -247,6 +266,10 @@ def _layernorm_mlp_fwd_rule( del kernel_1_axes, kernel_2_axes ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.forward.is_reduce_scatter + assert not collective_op_set_2.forward.is_all_gather # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) @@ -287,8 +310,10 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_1.forward, ) if use_bias_1 and tex.gemm_uses_jax_dot(): @@ -326,8 +351,10 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_2.forward, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -335,6 +362,8 @@ def _layernorm_mlp_fwd_rule( bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) + # sharding of outputs should be the same as dot_1's input + dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) ctx = ( @@ -364,6 +393,7 @@ def _layernorm_mlp_bwd_rule( norm_type, zero_centered_gamma, epsilon, + batch_sequence_transpose, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -372,6 +402,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + collective_op_sets, ctx, grad, ): @@ -410,6 +441,10 @@ def _layernorm_mlp_bwd_rule( ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.backward.is_all_gather + assert not collective_op_set_2.backward.is_reduce_scatter # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) @@ -436,6 +471,8 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set_2.backward, ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -450,6 +487,7 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -476,6 +514,8 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + transpose_batch_sequence=batch_sequence_transpose, + collective_op=collective_op_set_1.backward, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -486,6 +526,7 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=batch_sequence_transpose, ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 339e74e2f..7a8261269 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from typing import Callable, Optional import warnings + import jax import jax.numpy as jnp from jax.interpreters import pxla @@ -364,3 +365,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x + + +def tpsp_axis_size(): + """ + Get the size of the tensor parallelism axis. + Return 1 if no TP axis is set. + """ + return get_mesh_axis_size(global_mesh_resource().tpsp_resource) + + +def dp_or_fsdp_axis_size(): + """ + Get the size of the data parallelism or FSDP axis. + Return 1 if no DP/FSDP axis is set. + """ + dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) + fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) + return dp_size if dp_size > 1 else fsdp_size From a91e4585523f77a89cd41f12f3c869ee73572045 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 29 Sep 2025 11:35:34 -0400 Subject: [PATCH 011/141] [JAX] Add xml export for `test_multiprocessing_encoder` and `test_cgemm` (#2210) * add xml export for test_multiprocessing_encoder and test_cgemm Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/collective_gemm/run_test_cgemm.sh | 12 +++- .../run_test_multiprocessing_encoder.sh | 61 ++++++++++++++++--- 2 files changed, 63 insertions(+), 10 deletions(-) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh index 5bf7ccb59..af263eb53 100644 --- a/examples/jax/collective_gemm/run_test_cgemm.sh +++ b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -4,6 +4,10 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + # Check if NVLINK is supported before running tests echo "*** Checking NVLINK support***" NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) @@ -69,7 +73,8 @@ for TEST_FILE in "${TEST_FILES[@]}"; do # For process 0: show live output AND save to log file using tee echo "=== Live output from process 0 ===" pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ --num-processes=$NUM_GPUS \ --process-id=$i 2>&1 | tee "$LOG_FILE" & PID=$! @@ -94,8 +99,11 @@ for TEST_FILE in "${TEST_FILES[@]}"; do elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then echo "... $TEST_FILE FAILED" HAS_FAILURE=1 - else + elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then echo "... $TEST_FILE PASSED" + else + echo "... $TEST_FILE INVALID" + HAS_FAILURE=1 fi # Remove the log files after processing them diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 2a1ac0f8f..2a979e177 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -15,11 +15,37 @@ TEST_CASES=( "test_te_current_scaling_fp8_shardy" ) +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + echo echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM # Run each test case across all GPUs for TEST_CASE in "${TEST_CASES[@]}"; do echo @@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Define output file for logs LOG_FILE="${TEST_CASE}_gpu_${i}.log" - # Run pytest and redirect stdout and stderr to the log file - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ - --num-process=$NUM_GPUS \ - --process-id=$i > "$LOG_FILE" 2>&1 & - done + # For process 0: show live output AND save to log file using tee + if [ $i -eq 0 ]; then + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \ + "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done # Wait for the process to finish wait - tail -n +7 "${TEST_CASE}_gpu_0.log" # Check and print the log content accordingly if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" + HAS_FAILURE=1 elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE PASSED" else + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 - echo "... $TEST_CASE FAILED" fi # Remove the log file after processing it @@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do done wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + exit $HAS_FAILURE From dfeef1a26ba48ccbd690567a19137b2af8aeb7c9 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Mon, 29 Sep 2025 13:39:03 -0700 Subject: [PATCH 012/141] [JAX] Address tolerance check for current scaling dact dbias (#2211) Address tolerance check for current scaling dact Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9e39b84c0..7f15eec89 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -780,9 +780,15 @@ def _test_quantize_dact_dbias( assert_allclose(te_output.data, jax_output.data) if is_dbias: - # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. precise_comparison = not ( - in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. + (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) + # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. + or ( + activation_type == ("squared_relu",) + and in_dtype == jnp.bfloat16 + and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + ) ) assert_allclose( te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype From 3f5b47549567d13db76470073c8f0467c23d4fca Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 29 Sep 2025 14:12:26 -0700 Subject: [PATCH 013/141] [Core][PyTorch] NVFP4 recipe (#2177) * Add NVFP4 recipe Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Frank Sun Co-authored-by: Oleg Goncharov Co-authored-by: Zhongbo Zhu Co-authored-by: Evgeny Tsykunov Co-authored-by: Tim Moon Co-authored-by: Teddy Do * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add MathDx dependency to GitHub builds Signed-off-by: Tim Moon * Suggestions from GitHub Copilot Signed-off-by: Tim Moon * Move 2x shape logic from core to PyTorch Signed-off-by: Kirthi Shankar Sivamani * Fix compilation errors with CUDA 12.1 Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * SM 70 is not supported in CUDA 13 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Typo Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Revert "Move 2x shape logic from core to PyTorch" This reverts commit f8b2a2d0111d9af690b43bb98ae448d9a430a185. Signed-off-by: Tim Moon * Added dequantize kernel for FP4 Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warning Signed-off-by: Tim Moon * Add NVFP4 support with fusible ops Use logical tensor dims for PyTorch NVFP4 tensors. Temporarily add unfused dequantize impl. Fix bug where NVFP4 recipe was not configurable. Signed-off-by: Tim Moon * Fix logic for 2x shapes and move to PyTorch Signed-off-by: Kirthi Shankar Sivamani * Fix CG test model config Signed-off-by: Kirthi Shankar Sivamani * Debug NVFP4 tensor size function Signed-off-by: Tim Moon * Proper handling of the RNG state Signed-off-by: Przemek Tredak * Test SR properly Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix workspace size for GEMM heuristic. Signed-off-by: Kirthi Shankar Sivamani * Fix compile error in C++ NVFP4 test Some some numeric errors when blocks are all zero. Signed-off-by: Tim Moon * fix distrbuted test problem shape Signed-off-by: zhongboz * proper assert dim for low precision AG TP Signed-off-by: zhongboz * clean up duplicated code in nvfp4_utils.cuh Signed-off-by: zhongboz * lint Signed-off-by: zhongboz * pylint: disable=unused-argument Signed-off-by: zhongboz * `nvte_cublas_gemm_v2` to take alpha pointer (#12) * make nvte_cublas_gemm_v2 to take alpha/beta pointers Signed-off-by: Phuong Nguyen * users are expected to pass a valid C_tensor Signed-off-by: Phuong Nguyen * typos Signed-off-by: Phuong Nguyen * API to have const float* alpha Signed-off-by: Phuong Nguyen * Minor tweaks Support arbitrary beta scales. Increase workspace to be aligned to 128 bytes. Signed-off-by: Tim Moon * Debug IMA with alpha pointer Signed-off-by: Tim Moon --------- Signed-off-by: Phuong Nguyen Signed-off-by: Tim Moon Co-authored-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support fused amax kernels with NVFP4 quantization Signed-off-by: Tim Moon * Disable fused amax with cuDNN LayerNorm kernel Signed-off-by: Tim Moon * Add NVFP4 cases to distributed tests for TE ops Signed-off-by: Tim Moon * Change assert to NVTE_CHECK in the hadamard cast fusion Signed-off-by: Przemek Tredak * Fix compile error Signed-off-by: Tim Moon * Use global thread IDs for Philox subsequences Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add shape checks for NVFP4 cast kernels Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Do not fuse amax if cuDNN normalization is forced by envvar Signed-off-by: Przemek Tredak --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Przemek Tredak Signed-off-by: zhongboz Signed-off-by: Phuong Nguyen Co-authored-by: Frank Sun Co-authored-by: Oleg Goncharov Co-authored-by: Zhongbo Zhu Co-authored-by: Evgeny Tsykunov Co-authored-by: Tim Moon Co-authored-by: Teddy Do Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Przemek Tredak Co-authored-by: Phuong Nguyen --- .github/workflows/build.yml | 8 +- benchmarks/benchmark_rht_cast.py | 152 ++ build_tools/utils.py | 15 +- pyproject.toml | 3 +- qa/L0_pytorch_unittest/test.sh | 1 + qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/cpp/operator/CMakeLists.txt | 8 + tests/cpp/operator/test_cast_mxfp8.cu | 42 +- .../operator/test_cast_mxfp8_gated_swiglu.cu | 54 +- .../cpp/operator/test_cast_nvfp4_transpose.cu | 741 ++++++++ tests/cpp/test_common.cu | 239 ++- tests/cpp/test_common.h | 37 +- tests/pytorch/distributed/run_numerics.py | 242 ++- .../pytorch/distributed/run_numerics_exact.py | 718 ++++++++ tests/pytorch/distributed/test_fusible_ops.py | 18 +- tests/pytorch/distributed/test_numerics.py | 7 +- .../distributed/test_numerics_exact.py | 70 + tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 243 +++ .../pytorch/nvfp4/test_nvfp4_module_exact.py | 559 ++++++ .../nvfp4/test_nvfp4_quantize_exact.py | 495 ++++++ .../nvfp4/test_nvfp4_rht_quantize_exact.py | 255 +++ tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 238 +++ tests/pytorch/test_cuda_graphs.py | 71 +- .../test_float8_current_scaling_exact.py | 9 +- tests/pytorch/test_fusible_ops.py | 121 +- tests/pytorch/test_recipe.py | 37 + tests/pytorch/test_sanity.py | 34 + tests/pytorch/utils.py | 25 +- transformer_engine/common/CMakeLists.txt | 30 +- transformer_engine/common/common.cu | 12 +- transformer_engine/common/common.h | 51 +- transformer_engine/common/gemm/config.cpp | 116 ++ transformer_engine/common/gemm/config.h | 36 + .../common/gemm/cublaslt_gemm.cu | 345 +++- .../hadamard_transform/hadamard_transform.cu | 876 ++++++++++ .../hadamard_transform_cast_fusion.cu | 841 +++++++++ .../common/include/transformer_engine/gemm.h | 189 +- .../transformer_engine/hadamard_transform.h | 68 + .../include/transformer_engine/recipe.h | 4 + .../transformer_engine/transformer_engine.h | 50 +- .../common/normalization/layernorm/ln_api.cpp | 4 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +- transformer_engine/common/recipe/__init__.py | 114 +- .../common/recipe/current_scaling.cu | 27 +- transformer_engine/common/recipe/nvfp4.cu | 54 + transformer_engine/common/swizzle/swizzle.cu | 265 +-- .../common/transformer_engine.cpp | 86 +- .../common/transpose/cast_transpose.h | 9 + ...quantize_transpose_vector_blockwise_fp4.cu | 842 +++++++++ .../common/util/cast_gated_kernels.cuh | 5 +- .../common/util/cast_kernels.cuh | 807 ++++++++- .../common/util/dequantize_kernels.cuh | 110 +- .../common/util/nvfp4_transpose.cuh | 1515 +++++++++++++++++ transformer_engine/common/util/ptx.cuh | 82 +- .../common/util/pybind_helper.h | 3 +- transformer_engine/common/utils.cuh | 20 + transformer_engine/pytorch/constants.py | 2 + .../pytorch/cpp_extensions/gemm.py | 20 + transformer_engine/pytorch/csrc/common.cpp | 30 + transformer_engine/pytorch/csrc/common.h | 83 +- .../pytorch/csrc/extensions/activation.cpp | 244 ++- .../pytorch/csrc/extensions/attention.cpp | 18 +- .../pytorch/csrc/extensions/bias.cpp | 48 +- .../pytorch/csrc/extensions/gemm.cpp | 20 +- .../pytorch/csrc/extensions/normalization.cpp | 270 ++- .../pytorch/csrc/extensions/pybind.cpp | 19 + transformer_engine/pytorch/csrc/pybind.h | 20 +- transformer_engine/pytorch/csrc/quantizer.cpp | 590 ++++++- .../pytorch/csrc/type_converters.cpp | 40 + transformer_engine/pytorch/csrc/util.cpp | 55 +- transformer_engine/pytorch/distributed.py | 263 ++- .../pytorch/experimental/__init__.py | 10 + .../pytorch/experimental/config.py | 201 +++ .../pytorch/experimental/gemm.py | 139 ++ .../pytorch/experimental/quantization.py | 203 +++ .../quantization_microblock_ref.py | 811 +++++++++ .../pytorch/experimental/utils.py | 30 + transformer_engine/pytorch/fp8.py | 105 ++ transformer_engine/pytorch/module/_common.py | 38 +- transformer_engine/pytorch/module/base.py | 15 +- .../pytorch/module/layernorm_linear.py | 43 +- .../pytorch/module/layernorm_mlp.py | 48 +- transformer_engine/pytorch/module/linear.py | 45 +- .../pytorch/ops/basic/basic_linear.py | 8 + transformer_engine/pytorch/tensor/__init__.py | 3 + .../tensor/_internal/nvfp4_tensor_base.py | 348 ++++ .../pytorch/tensor/mxfp8_tensor.py | 5 +- .../pytorch/tensor/nvfp4_tensor.py | 898 ++++++++++ .../pytorch/tensor/quantized_tensor.py | 4 + transformer_engine/pytorch/tensor/utils.py | 21 +- transformer_engine/pytorch/triton/pad.py | 94 + transformer_engine/pytorch/utils.py | 14 +- 92 files changed, 15060 insertions(+), 753 deletions(-) create mode 100644 benchmarks/benchmark_rht_cast.py create mode 100644 tests/cpp/operator/test_cast_nvfp4_transpose.cu create mode 100644 tests/pytorch/distributed/run_numerics_exact.py create mode 100644 tests/pytorch/distributed/test_numerics_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_module_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py create mode 100644 tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py create mode 100755 tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py create mode 100644 transformer_engine/common/gemm/config.cpp create mode 100644 transformer_engine/common/gemm/config.h create mode 100644 transformer_engine/common/hadamard_transform/hadamard_transform.cu create mode 100644 transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu create mode 100644 transformer_engine/common/include/transformer_engine/hadamard_transform.h create mode 100644 transformer_engine/common/recipe/nvfp4.cu create mode 100644 transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu create mode 100644 transformer_engine/common/util/nvfp4_transpose.cuh create mode 100644 transformer_engine/pytorch/experimental/__init__.py create mode 100644 transformer_engine/pytorch/experimental/config.py create mode 100644 transformer_engine/pytorch/experimental/gemm.py create mode 100644 transformer_engine/pytorch/experimental/quantization.py create mode 100644 transformer_engine/pytorch/experimental/quantization_microblock_ref.py create mode 100644 transformer_engine/pytorch/experimental/utils.py create mode 100644 transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py create mode 100644 transformer_engine/pytorch/tensor/nvfp4_tensor.py create mode 100644 transformer_engine/pytorch/triton/pad.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f40b28189..506bc83f0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake==3.21.0 pybind11[global] ninja + pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -43,7 +43,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -63,7 +63,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install pybind11[global] + run: pip install pybind11[global] nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: @@ -83,7 +83,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript + run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1 - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/benchmarks/benchmark_rht_cast.py b/benchmarks/benchmark_rht_cast.py new file mode 100644 index 000000000..9c47856f7 --- /dev/null +++ b/benchmarks/benchmark_rht_cast.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import torch +import pandas as pd +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +import transformer_engine.pytorch.cpp_extensions as ext + +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +scale_padding_to = 1 +permute_scale = False + +TORCH_TO_TE_FLOAT_MAP = { + torch.bfloat16: tex.DType.kBFloat16, +} + + +def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16): + # Generate random input data + M, K = shape + x = torch.randn([M, K], dtype=input_dtype, device="cuda") + + assert shape[0] % 16 == 0, "Shape must be divisible by 16" + assert shape[1] % 16 == 0, "Shape must be divisible by 16" + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + stochastic_rounding=stochastic_rounding, + ) + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, K), dtype=x.dtype, device=x.device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + with torch.no_grad(): + stmt = "kernel_func(input, output)" + globals_dict = { + "kernel_func": nvfp4_quantizer.update_quantized, + "input": x, + "output": x_nvfp4_sut, + } + + timing = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + ).blocked_autorange(min_run_time=5) + print(timing) + timing_us = timing.median * 1e6 + + input_nbytes = shape[0] * shape[1] * 2 # bf16 + output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4 + sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems + + total_nbytes = ( + 0 + + input_nbytes + * 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reaindg input for Cast(RHT(x.T)) + + 2 * 4 # Output 2 * float for scale & amax + + 2 * 4 # Input 2 * float + + output_nbytes * 2 # Output from Cast(x) and Cast(RHT(x.T)) + + sf_nbytes * 2 # Scale factor + ) + + throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6) + + print( + f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:" + f" {throughput_GBps} GB/s" + ) + return timing_us, throughput_GBps + + +# Nsight Compute Profiling Command: +# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + args = parser.parse_args() + + if args.profile: + print("Profiling is enabled.") + else: + print("Profiling is disabled.") + + shapes = [ + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ] + + if args.profile: + shapes = [ + (16384, 6144), + ] + + data = [] + for stochastic_rounding in [True]: # , False]: + for shape in shapes: + print( + f"Running benchmark_func with shape {shape} and stochastic_rounding" + f" {stochastic_rounding}" + ) + timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding) + data.append( + [ + "benchmark_func", + shape, + stochastic_rounding, + timing_us, + throughput_GBps, + ] + ) + + df = pd.DataFrame( + data=data, + columns=[ + "kernel", + "shape", + "stochastic_rounding", + "timing_us", + "throughput(GB/s)", + ], + ) + print(df) + df.to_csv("benchmark_cast_nvfp4.csv", index=False) diff --git a/build_tools/utils.py b/build_tools/utils.py index 23fb56598..3d8ec462c 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -234,15 +234,18 @@ def get_cuda_include_dirs() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - version = cuda_version() - if os.getenv("NVTE_CUDA_ARCHS") is None: + archs = os.getenv("NVTE_CUDA_ARCHS") + if archs is None: + version = cuda_version() if version >= (13, 0): - os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120" + archs = "75;80;89;90;100;100a;103a;120" + elif version >= (12, 9): + archs = "70;80;89;90;100;100a;103a;120" elif version >= (12, 8): - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120" + archs = "70;80;89;90;100;100a;120" else: - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90" - return os.getenv("NVTE_CUDA_ARCHS") + archs = "70;80;89;90" + return archs def cuda_version() -> Tuple[int, ...]: diff --git a/pyproject.toml b/pyproject.toml index 64ff4c5ce..8692ad961 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,8 +3,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", -"torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 394273ca4..cdf0df888 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -31,6 +31,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 19889946a..e698e997a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 498c1d394..479d378ba 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -11,6 +11,7 @@ add_executable(test_operator test_cast_mxfp8_gated_swiglu.cu test_qdq.cu test_cast_mxfp8.cu + test_cast_nvfp4_transpose.cu test_cast_float8blockwise.cu test_dequantize_mxfp8.cu test_transpose.cu @@ -31,6 +32,13 @@ add_executable(test_operator test_swap_first_dims.cu ../test_common.cu) +# Add profiling and debug flags for CUDA compilation +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage +# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping + +# Find required packages find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 49bbf1655..380092144 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -81,6 +81,7 @@ void compute_ref(const ProcessingMethod processing_method, // Cache computations for (size_t i = i_min; i < i_max; ++i) { for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); @@ -310,12 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method, const double rel_tolerable_mismatches_limit = 0.0; size_t mismatches_scales = 0; - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + + compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); const size_t mismatches_elts = 32 * mismatches_scales; auto [atol, rtol] = getTolerances(otype); @@ -481,22 +483,22 @@ void performTest_x2(const ProcessingMethod processing_method, const double rel_tolerable_mismatches_limit = 0.0; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 464b77128..512ee7e81 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -267,19 +267,20 @@ void performTest_x1(const size_t rows, ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + } const size_t mismatches_elts = 32 * mismatches_scales; @@ -378,21 +379,22 @@ void performTest_x2(const size_t rows, const double rel_tolerable_mismatches_limit = 1.0e-4; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu new file mode 100644 index 000000000..e905a0064 --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -0,0 +1,741 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +std::vector create_transpose(const InputType* const input, const size_t rows, size_t cols) { + std::vector input_t(cols * rows); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const size_t idx_t = j * rows + i; + input_t[idx_t] = input[idx]; + } + } + return input_t; +} + +// Compute the global encode scale factor for a given global amax +float compute_global_encode_scaling_factor_FP4(const float global_amax) { + constexpr float fp8_max = 448.0f; // 448.0f; + constexpr float fp4_max = 6.0f; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +// 1D Scaling: Original implementation with 1x16 blocks +template +void quantize_nvfp4_1d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + + constexpr size_t block_size_X = 16; + const size_t blocks_X = divide_round_up(cols, block_size_X); + + std::array cache_buffer; + for (size_t i = 0; i < block_size_X; ++i) { + cache_buffer[i] = 0.0f; + } + + for (size_t i = 0; i < rows; ++i) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t j_min = block_X * block_size_X; + const size_t j_max = j_min + block_size_X; + + // Find block amax + float block_amax = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx] = elt; + block_amax = std::max(block_amax, std::abs(elt)); + } + + // 2. Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const float S_dec_b_fp8 = S_dec_b * S_enc; + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = static_cast(S_dec_b_fp8); + const float scale_reciprocal = S_enc_b_fp8; + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + + // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + } + } + } +} + +// Compute 2D mathematical scaling factors (8x8 for 128x128 input) +template +void compute_2d_mathematical_scales(float (*OP)(const float), + const InputType* const input, + const size_t rows, + const size_t cols, + const float global_amax, + std::vector>& math_scales) { + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + math_scales.resize(blocks_Y, std::vector(blocks_X)); + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Find 2D block amax over entire 16x16 region + float block_amax = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + block_amax = std::max(block_amax, std::abs(elt)); + } + } + + // Compute E4M3 scaling factor for this 16x16 block + const float S_dec_b = block_amax / 6.0f; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + math_scales[block_Y][block_X] = S_dec_b_fp8; + } + } +} + +// 2D Scaling: NEW implementation with proper replication +template +void quantize_nvfp4_2d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Step 1: Compute mathematical 8x8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr + if (scales != nullptr) { + // Each of the 128 rows gets scaling factors from its corresponding 16×16 block + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + } + + // Step 3: Apply quantization using the mathematical scaling factors + std::array, block_size_Y> cache_buffer; + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Get the scaling factor for this block + const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float scale_reciprocal = S_enc_b_fp8; + + // Process and cache data for this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx_y][cache_idx_x] = elt; + } + } + + // Apply scaling to all elements in this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x1 = j - j_min; + const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1); + + const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1]; + const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ? + cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f; + + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + } + } + } + } +} + +// Wrapper function that calls appropriate implementation based on 2D flag +template +void quantize_nvfp4(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_2d_quantization = false) { + if (use_2d_quantization) { + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } else { + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } +} + +template +void compute_ref(float (*OP)(const float), + const InputType* input, + fp4e2m1x2* output, + fp4e2m1x2* output_t, + fp8e4m3* scales, + fp8e4m3* scales_t, + const float global_amax, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const size_t scales_stride_t, + const bool use_2d_quantization = false) +{ + std::vector input_t = create_transpose(input, rows, cols); + + if (use_2d_quantization) { + // Step 1: Compute mathematical 8×8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Generate scales (128×8) by replicating row-wise + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + + // Step 3: Generate scales_t (128×8) with proper transposed block mapping + for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data + const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X + for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate + const size_t scale_idx = i * scales_stride_t + block_Y_new; + scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig]; + } + } + + // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d + // (This part processes the actual FP4 data using the mathematical scaling factors) + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + + } else { + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + } +} + +void compare_nvfp4_tensors(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8) { + std::vector mismatch_messages; + size_t total_mismatches = 0; + + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + total_mismatches++; + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); + + // Optional: limit number of detailed messages to avoid overwhelming output + if (mismatch_messages.size() <= 100) { + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } + } + } + } + + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > 100) { + std::cout << "... and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; + } +} + +// Optional: Function to dump tensor data to files for detailed analysis +void dump_nvfp4_tensor_data(const std::string& prefix, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + std::string test_file = prefix + "_test.txt"; + std::string ref_file = prefix + "_ref.txt"; + std::string diff_file = prefix + "_diff.txt"; + + std::ofstream test_out(test_file); + std::ofstream ref_out(ref_file); + std::ofstream diff_out(diff_file); + + if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + const int pos = idx + k; + + test_out << "pos[" << pos << "] = " << t << std::endl; + ref_out << "pos[" << pos << "] = " << r << std::endl; + diff_out << "pos[" << pos << "] test=" << t << " ref=" << r + << " abs_diff=" << fabs(t - r) + << " rel_diff=" << (r == 0 ? 0.0 : fabs((t - r) / r)) << std::endl; + } + } + } + std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl; + } else { + std::cout << "WARNING: Could not open files for tensor data dump" << std::endl; + } +} + +void print_detailed_tensor_comparison(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n", + name.c_str(), rows, cols, rows * cols); + + const int total_elements = rows * cols; + const int check_count = 128; + + printf("--- FIRST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = 0; i < std::min(check_count, total_elements); ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + + if (total_elements > 2 * check_count) { + printf("\n--- LAST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = total_elements - check_count; i < total_elements; ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + } + printf("==================================\n"); +} + +void compareResults_nvfp4(const Tensor &test, + const void *ref, const void *ref_t, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + + // Print detailed element-by-element comparison + // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); + // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows); + + // Optionally dump tensor data to files for detailed analysis + if (dump_data) { + dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + + compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); +} + +template +void performTest(float (*OP)(const float), + const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + // Use get_scale_tensor_dims for NVFP4 scale tensor dimensions + // Now that CheckScaleTensorShape is fixed, this should work correctly + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + Tensor input("input", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); + std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); + std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); + std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + + fillCase(&input, InputsFillCase::uniform); + + // Find global amax + float amax = 0.0f; + const InputType* input_dptr = input.rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + amax = fmaxf(amax, static_cast(input_dptr[idx])); + } + } + // Set 2nd stage NVFP4 scaling factor + output.set_scale(amax); + + bool use_2d_quantization = false; + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + output.scale(), + rows, + cols, + scales_stride, + scales_stride_t, + use_2d_quantization); + + QuantizationConfigWrapper quant_config; + + // Initialize stochastic rounding + Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); + rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed + rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence + rng_state.from_cpu(); + quant_config.set_stochastic_rounding(false); + quant_config.set_rng_state(rng_state.data()); + + // Set 2D quantization based on compile-time flag + quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + + // Call appropriate function based on operation type + // Activation functions take 3 parameters (input, output, stream) + // nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream) + if (OP == &gelu) { + nvte_gelu(input.data(), output.data(), 0); + } else if (OP == &silu) { + nvte_silu(input.data(), output.data(), 0); + } else if (OP == &relu) { + nvte_relu(input.data(), output.data(), 0); + } else if (OP == &qgelu) { + nvte_qgelu(input.data(), output.data(), 0); + } else if (OP == &srelu) { + nvte_srelu(input.data(), output.data(), 0); + } else { + nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err)); + } + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + const double atol = 0.05; + const double rtol = 0.1; + + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); + + const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_ptr = ref_scales.get(); + const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + scale_mismatches_num); + + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); +} + +std::vector> tensor_dims = { + {32, 32}, + {32, 64}, + {64, 32}, + {64, 96}, + {128, 128}, + {256, 256}, + {512, 512}, + {1024, 1024}, + {2048, 2048}, + {128, 256}, + {8192, 128}, + {2048, 160}, + {8, 32, 1024}, + {16, 8, 4, 512}, + {1024, 16384}, + {4096, 13312}, +}; + +// Only GeLU activation tests are supported +std::vector Activation_types = { + ActivationType::Identity, + ActivationType::GeLU, + ActivationType::SiLU, + ActivationType::ReLU, + ActivationType::QGeLU, + ActivationType::SReLU, +}; + +} // namespace + +class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam + , + transformer_engine::DType>> {}; + +TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ActivationType Act_type = std::get<0>(GetParam()); + const auto tensor_dims = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + + // Skip tests if the input tensor is 1D + if (tensor_dims.size() < 2) { + GTEST_SKIP(); + } + + // Forward activations + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + performTest(OP, tensor_dims); + ); +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "CAST_ONLY"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kBFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + return name; + }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index f974d9083..cdbfb05b3 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -107,6 +107,10 @@ size_t DIVUP(const size_t &x, const size_t &y){ return (((x) + ((y)-1)) / (y)); } +size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){ + return DIVUP(x, y) * y; +} + struct scale_inv_meta { std::vector shape; DType type; @@ -143,21 +147,71 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul, 4ul}; - { - auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; - alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + + ret_rowwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type = DType::kFloat8E8M0; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } - { - auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; - alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; - ret_colwise.shape = {scale_dim_0, scale_dim_1}; + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + NVTE_CHECK(last_dim % 32 == 0); + NVTE_CHECK(first_dim % 32 == 0); + + scale_inv_meta ret_rowwise, ret_colwise; + + size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; + + size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; + + ret_rowwise.type = DType::kFloat8E4M3; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + ret_colwise.type = DType::kFloat8E4M3; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + ret_rowwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0; ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); @@ -176,13 +230,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -202,13 +256,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(first_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(first_dim, 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(last_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(last_dim, 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -250,14 +304,15 @@ Tensor::Tensor(const std::string& name, NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING + || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } } else { - // Same shape for MX + // Same shape for MX and NVFP4 for (size_t i = 0; i < shape.ndim; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } @@ -283,10 +338,13 @@ Tensor::Tensor(const std::string& name, std::fill_n(cpu_data_columnwise_.get(), total_size, 0); } } - tensor_.set_rowwise_data(dptr_rowwise, type, shape); - tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); - if (isFp8Type(type)) { + const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); + tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); + + if (isFp8Type(type) || isFp4Type(type)) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) cudaMemset(amax, 0, sizeof(float)); @@ -305,13 +363,19 @@ Tensor::Tensor(const std::string& name, } if (columnwise) { tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); + std::vector{1}); columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(normalized_shape, tensor_.scaling_mode()); + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + // Used for NVFP4 second stage scaling + cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + cudaMemset(scale, 0, sizeof(float)); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); auto rowwise_scale_size = rowwise_scale_meta.bytes(); auto columnwise_scale_size = colwise_scale_meta.bytes(); auto scale_shape = rowwise_scale_meta.shape; @@ -346,13 +410,16 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { + const DType colwise_type = tensor_.dtype(); + + const size_t colwise_size = bytes(s, colwise_type); cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); + tensor_.get_columnwise_data().data_ptr, + colwise_size, + cudaMemcpyDeviceToHost); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), @@ -364,8 +431,7 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -394,15 +460,15 @@ void Tensor::from_cpu() const { cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { if (tensor_.amax() != nullptr){ cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -419,7 +485,7 @@ void Tensor::from_cpu() const { } void Tensor::set_scale(float scale) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { NVTE_CHECK(scale_cpu_data_); if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; @@ -429,7 +495,7 @@ void Tensor::set_scale(float scale) { } void Tensor::set_scale_inv(float scale_inv) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { NVTE_CHECK(rowwise_scale_inv_cpu_data_); } @@ -437,8 +503,7 @@ void Tensor::set_scale_inv(float scale_inv) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(tensor_.shape(), tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); if (num_scales == 1) { @@ -468,7 +533,8 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) + || isFp4Type(dtype()) && isFp4Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), @@ -681,12 +747,30 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit) +template +struct CastToType; + +template <> +struct CastToType { + using type = int; +}; + +template <> +struct CastToType { + using type = float; +}; + +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + using UpcastType = typename CastToType::type; + auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); + + const size_t N = row_blocks * col_blocks; const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); @@ -696,11 +780,31 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - const int test_val = static_cast(test[idx]); - const int ref_val = static_cast(ref[idx]); - const int abs_delta = std::abs(test_val - ref_val); + float t, r; + + bool assertion = false; - if (abs_delta > atol) { + if (std::is_same::value) { + t = static_cast(test[idx]); + r = static_cast(ref[idx]); + assertion = std::abs(t - r) > atol; + } else { + t = static_cast(*reinterpret_cast(&test[idx])); + r = static_cast(*reinterpret_cast(&ref[idx])); + const bool mismatch = (fabs(t - r) > atol_fp8e4m3) + && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3); + if (mismatch) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + } + if (assertion) { mismatches_num++; mismatch_indices.push_back(idx); } @@ -708,8 +812,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, std::cout << "Error in " << name << std::endl; for (const int index : mismatch_indices) { std::cout << "Mismatch at (" << index << "):" - << static_cast(test[index]) << " vs " - << static_cast(ref[index]) << std::endl; + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; } GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " << tolerable_mismatches_limit << "."; @@ -718,6 +822,22 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, } } +// Instantiate templates +template +void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +template +void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + + std::pair getTolerances(const DType type) { switch(type) { case DType::kFloat32: @@ -873,6 +993,10 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } +bool isFp4Type(DType type) { + return type == DType::kFloat4E2M1; +} + int32_t getDeviceComputeCapability() { cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, 0); @@ -894,7 +1018,8 @@ std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols) { - const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool is_rowwise = (block_size_rows == 1) + && ((block_size_cols == 32) || (block_size_cols == 16)); const size_t alignment_Y = is_rowwise ? scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index d1e273c6d..b8993dfb6 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -62,6 +62,8 @@ using fp8e5m2 = __nv_fp8_e5m2; using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif template @@ -223,7 +225,9 @@ class Tensor { float scale() const { if(scale_cpu_data_) { - NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); + NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), + "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -237,6 +241,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -250,6 +256,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -304,10 +312,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_rowwise = 128; -constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +constexpr size_t scale_tensor_alignment_X_colwise = 128; inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; @@ -456,12 +464,14 @@ void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); + std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols); @@ -484,6 +494,7 @@ const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +bool isFp4Type(DType type); int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; @@ -561,7 +572,7 @@ constexpr int32_t blackwellComputeCapability = 100; SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ printf("dtype: %d\n", static_cast(dtype)); \ - NVTE_ERROR("Invalid type MARKED TEST."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -580,7 +591,7 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 2."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ @@ -588,7 +599,7 @@ constexpr int32_t blackwellComputeCapability = 100; using namespace transformer_engine; \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 3."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ @@ -613,5 +624,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 4."); \ + NVTE_ERROR("Invalid type."); \ } diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 21aab6336..a4aa74bd8 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -9,6 +9,7 @@ import os import sys from functools import wraps +import math import transformer_engine.pytorch as te import torch @@ -20,10 +21,15 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, Format, Recipe, + QParams, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.distributed import gather_along_first_dim from run_layer_with_overlap import _compare_tensors SEQ_LEN, BATCH_SIZE = 16, 16 @@ -47,6 +53,14 @@ ) +def nvfp4_vanilla(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams() + nvfp4_recipe.fp4_quant_fwd_weight = QParams() + nvfp4_recipe.fp4_quant_bwd_grad = QParams() + return nvfp4_recipe + + # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "fp8": @@ -59,6 +73,8 @@ def quantization_recipe() -> Recipe: return Float8CurrentScaling() if QUANTIZATION == "fp8_block_scaling": return Float8BlockScaling() + if QUANTIZATION == "nvfp4": + return nvfp4_vanilla() return te.fp8.get_default_fp8_recipe() @@ -96,10 +112,14 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"): SEQ_LEN = 32 BATCH_SIZE = 32 HIDDEN_SIZE = 128 + # For fp8 block scaling, block size is 128, + # and to make low precision TP work, input tensor + # must be 128x128 divisible to be eligible for + # low precision All-Gather when needed elif QUANTIZATION == "fp8_block_scaling": SEQ_LEN = 128 BATCH_SIZE = 128 @@ -107,6 +127,7 @@ def main(argv=None, namespace=None): test_dict = [ test_quantizer, + test_quantized_all_gather, test_linear, test_layernorm, test_layernorm_linear, @@ -176,6 +197,9 @@ def _get_tolerances(dtype): # row parallel & sequence parallel, because we do the all_gather in backward pass if QUANTIZATION == "fp8_cs": return {"rtol": 0.4, "atol": 0.25} + elif QUANTIZATION == "nvfp4": + # TODO(zhongboz): investigate why the tolerance is so large + return {"rtol": 0.125, "atol": 0.12} elif QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} @@ -326,24 +350,36 @@ def _alloc_main_grad(model_single_node, model_distributed): ############################################### # Quantizer # ############################################### -def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): +def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size): """ quantizer is the reference quantizer on a single GPU. quantizer_dist is the distributed quantizer to be tested on multiple GPUs. """ if quantizer_class == Float8CurrentScalingQuantizer: quantizer_dist = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=True, amax_reduction_group=tp_group, ) quantizer = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=False, ) return quantizer, quantizer_dist + elif quantizer_class == NVFP4Quantizer: + quantizer_dist = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=True, + amax_reduction_group=tp_group, + ) + quantizer = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=False, + amax_reduction_group=None, + ) + return quantizer, quantizer_dist else: raise ValueError(f"Unsupported quantizer class: {quantizer_class}") @@ -414,6 +450,194 @@ def test_quantizer(): _test_quantizer(input_dtype, fp8_dtype) +############################################ +# Quantized All-Gather # +############################################ + + +def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape): + """ + Zero padding the scale_inv. + scale_inv shape is the padded shape, but not zero padded + unpadded_shape is the original shape before padding + """ + dim0, dim1 = scale_inv.shape + unpadded_dim0, unpadded_dim1 = unpadded_shape + pad_dim0 = (128 - unpadded_dim0 % 128) % 128 + pad_dim1 = (4 - unpadded_dim1 % 4) % 4 + new_dim0 = unpadded_dim0 + pad_dim0 + new_dim1 = unpadded_dim1 + pad_dim1 + + assert dim0 == new_dim0 + assert dim1 == new_dim1 + + # return input if no padding is needed + if pad_dim0 == 0 and pad_dim1 == 0: + return scale_inv + + # unpad first to remove random bits from torch empty + scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous() + # using torch padding + new_scale_inv = torch.nn.functional.pad( + scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0 + ) + + assert new_scale_inv.shape == (new_dim0, new_dim1) + + return new_scale_inv + + +def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise): + """ + Calculate the unpadded shape of the scale_inv tensor. + """ + M, K = 1, 1 + M = math.prod(input_shape[:-1]) + K = input_shape[-1] + + if quantizer_cls == NVFP4Quantizer: + if columnwise: + outer = K + inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + outer = M + inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + raise ValueError(f"Unsupported quantizer class: {quantizer_cls}") + + +@run_distributed_test() +def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls): + """Test the quantizer under distributed settings. + + Args: + input_dtype (torch.dtype): The data type of the input. + low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8. + """ + + M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2 + + # high precision input + x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype) + # set one element of the input to a very large value, which doesn't live in rank 0 after the split + # to test the amax reduction on purpose + # x_hp_cpu[M - 1, N - 1] = 1e4 + + # get the unpadded shapes + unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False) + unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True) + + # rank 0 takes the full copy and quantize with GPU 0 for verification + if WORLD_RANK == 0: + x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda") + x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK] + + # Create quantizers + quantizer, quantizer_dist = _construct_quantizer( + quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE + ) + + # quantize the entire input + if WORLD_RANK == 0: + x_low_precision_single = quantizer(x_hp_rank0) + + # run all-gather with a quantizer as input for quantized all-gather + x_low_precision_total, _ = gather_along_first_dim( + x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist + ) + + # check the outputs + if WORLD_RANK == 0: + # assert all data and scale_inv are the same + torch.testing.assert_close( + x_low_precision_single._rowwise_data, + x_low_precision_total._rowwise_data, + rtol=0.0, + atol=0.0, + ) + # check the rowwise scale without any padding + unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape + unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_rowwise_scale_inv_ref, + unpadded_rowwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + x_low_precision_single._columnwise_data, + x_low_precision_total._columnwise_data, + rtol=0.0, + atol=0.0, + ) + unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape + unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_columnwise_scale_inv_ref, + unpadded_columnwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + + +def test_quantized_all_gather(): + """ + Run quantized all-gather tests with various configurations. + """ + # skip this test for other quantization schemes + is_nvfp4 = QUANTIZATION == "nvfp4" + # add other recipes for testing if needed + if not is_nvfp4: + return + + input_dtypes = [torch.bfloat16] + fp4_dtype = [tex.DType.kFloat4E2M1] + fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + quantizer_cls_nvfp4 = [NVFP4Quantizer] + # add FP8 quantizers if needed + quantizer_cls_fp8 = [] + + low_precisio_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype + quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8 + + for quantizer_cls in quantizer_cls_list: + for input_dtype in input_dtypes: + for low_precision_dtype in low_precisio_dtypes: + _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls) + + ############################################ # Linear # ############################################ @@ -514,10 +738,11 @@ def test_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"delay_wgrad_compute": True}, {"save_original_input": True}, ] + for kwargs in kwargs_list: if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": continue @@ -693,11 +918,12 @@ def test_layernorm_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, {"delay_wgrad_compute": True}, ] + for kwargs in kwargs_list: for parallel_mode in ["column"]: for sequence_parallel in [False, True]: @@ -799,7 +1025,7 @@ def test_layernorm_mlp(): {"normalization": "RMSNorm"}, {"zero_centered_gamma": True}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"activation": "relu"}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, @@ -897,7 +1123,7 @@ def test_transformer_layer(): {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True}, {"qkv_weight_interleaved": False}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"fuse_qkv_params": True}, {"activation": "relu"}, ] diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py new file mode 100644 index 000000000..b1722b79a --- /dev/null +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -0,0 +1,718 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import datetime +import os +import sys +from functools import wraps +import math + +import transformer_engine.pytorch as te +import torch +from torch import nn +import torch.distributed as dist +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + NVFP4BlockScaling, + Format, + Recipe, + QParams, +) +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from run_layer_with_overlap import _compare_tensors + + +BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128 +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +LOSS_FN = nn.MSELoss() +QUANTIZATION = None + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +# Quantization recipe setup +def quantization_recipe() -> Recipe: + if QUANTIZATION == "nvfp4": + return nvfp4_rht_and_2d_quantization() + raise ValueError(f"Unsupported quantization: {QUANTIZATION}") + + +def setup_environment_for_reference(): + if QUANTIZATION == "nvfp4": + os.environ["QAT_PARAMS"] = "9003" + else: + raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}") + + +def cleanup_environment(): + if "QAT_PARAMS" in os.environ: + del os.environ["QAT_PARAMS"] + + +def main(argv=None, namespace=None): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + parser = argparse.ArgumentParser() + parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--hidden-size", type=int, default=128) + parser.add_argument("--out-size", type=int, default=128) + args = parser.parse_args(argv, namespace) + + # Quantization scheme + QUANTIZATION = args.quantization + BATCH_SIZE = args.batch_size + HIDDEN_SIZE = args.hidden_size + OUT_SIZE = args.out_size + + test_dict = [ + test_linear, + test_layernorm_linear, + ] + + for test in test_dict: + test() + dist.destroy_process_group() + return 0 + + +def run_distributed_test(test_name=None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + name = test_name if test_name is not None else func.__name__ + + dist_print(f"Starting test {name} with args {args} and {kwargs}") + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + func(*args, **kwargs) + + dist.barrier() + dist_print(f"Passed test {name}") + + return wrapper + + return decorator + + +def dist_print(msg, src=None, end="\n", error=False): + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + + +############################################ +# Linear # +############################################ +class TestDistributedLinearBase: + @staticmethod + def _prepare_data( + batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32 + ): + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda") + w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda") + bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None + gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda") + + return x, w, bias, gradient + + @staticmethod + def _shard_tensor(x, world_size, axis): + split_size = x.size()[axis] // world_size + split_tensor = torch.split(x, split_size, axis) + out = [] + for tensor in split_tensor: + out.append(tensor.detach().clone().requires_grad_(x.requires_grad)) + return out + + @staticmethod + def _gather_tensor(local, world_size, tp_group, concat_dim): + out_list = [torch.zeros_like(local) for _ in range(world_size)] + torch.distributed.all_gather(out_list, local, tp_group) + return torch.cat(out_list, dim=concat_dim) + + @staticmethod + def _all_reduce_tensor(local, world_size, tp_group): + if world_size == 1: + return local + handle = torch.distributed.all_reduce(local, group=tp_group, async_op=False) + return local + + @staticmethod + def _get_sum_abs_error(a, b): + return torch.sum(torch.abs(a - b)) + + @staticmethod + def _get_mean_abs_relative_error(a, b): + error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b)) + return torch.mean(error) + + @classmethod + def run_linear_preprocess_parallel( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_size=1, + rank=0, + ): + if tp_size > 1: + if parallel_mode == "column": + # split w in N dim, which should be axis 0 + w = cls._shard_tensor(w, tp_size, 0)[rank] + bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None + # split gradient in N dim, which should be axis 1 + gradient = cls._shard_tensor(gradient, tp_size, 1)[rank] + if sequence_parallel: + # split x in M dim, which should be axis 0 + x = cls._shard_tensor(x, tp_size, 0)[rank] + # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1 + if parallel_mode == "row": + # split x in K dim, which should be axis 1 + x = cls._shard_tensor(x, tp_size, 1)[rank] + # split w in K dim, which should be axis 1 + w = cls._shard_tensor(w, tp_size, 1)[rank] + if sequence_parallel: + # split gradient in M dim, which should be axis 0 + gradient = cls._shard_tensor(gradient, tp_size, 0)[rank] + return x, w, bias, gradient + + @classmethod + def run_linear_postprocess_parallel( + cls, + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ): + if tp_size > 1: + if parallel_mode == "column": + # gather y_q in N dim, which should be axis 1 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1) + # gather wgrad in N dim, which should be axis 0 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0) + # gather bgrad in N dim, which should be axis 0 + bgrad = ( + cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None + ) + if sequence_parallel: + # gather dgrad in M dim, which should be axis 0 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0) + if parallel_mode == "row": + # gather dgrad in K dim, which should be axis 1 + dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1) + # gather wgrad in K dim, which should be axis 1 + wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1) + if sequence_parallel: + # gather y_q in M dim, which should be axis 0 + y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0) + # we need to sum bias gradient when using TP + SP + bgrad = ( + cls._all_reduce_tensor(bgrad, tp_size, tp_group) + if bgrad is not None + else None + ) + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_one_step( + cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False + ): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + if isinstance(layer, te.Linear): + # Kitchen Linear + y_q = layer.forward(x, is_first_microbatch=is_first_microbatch) + else: + # the default torch.nn.Linear + y_q = layer(x) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + bgrad = ( + layer._parameters["bias"].grad + if layer._parameters.get("bias", None) is not None + else None + ) + assert "weight" in layer._parameters + if fuse_wgrad_accumulation: + wgrad = layer._parameters["weight"].main_grad + assert layer._parameters["weight"].grad is None + else: + wgrad = layer._parameters["weight"].grad + + return y_q, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation=False, + ): + """ + Run multiple steps of linear layer and collect results. + """ + + y_q_list, dgrad_list, wgrad_list = [], [], [] + bgrad_list = [] if layer._parameters.get("bias", None) is not None else None + + for i in range(run_num_steps): + x_i = (x + i).clone().detach().requires_grad_(True) + # run_linear_one_step + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step( + layer, + x_i, + gradient, + is_first_microbatch=(i == 0) if enable_weight_cache else None, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + # Collect results + y_q_list.append(y_q.detach().clone()) + dgrad_list.append(dgrad.detach().clone()) + wgrad_list.append(wgrad.detach().clone()) + if bgrad_list is not None and bgrad is not None: + bgrad_list.append(bgrad.detach().clone()) + + # Stack the results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + + @classmethod + def run_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + fuse_wgrad_accumulation=False, + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = te.Linear( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + ) + + layer = layer.to("cuda") + + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + if fuse_wgrad_accumulation: + assert ( + run_num_steps > 1 + ), "Fused weight gradient accumulation requires run_num_steps > 1" + layer.weight.main_grad = torch.zeros_like(layer.weight) + + # Run one step or multiple steps + if run_num_steps == 1: + y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + else: + y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps( + layer, + x, + gradient, + run_num_steps, + enable_weight_cache, + fuse_wgrad_accumulation, + ) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, dgrad, wgrad, bgrad + + +@run_distributed_test() +def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. + + Args: + parallel_mode (str): 'row' or 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + + QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference + """ + params_dtype = torch.bfloat16 + use_bias = kwargs.get("bias", True) + fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False) + seed = torch.initial_seed() + recipe = quantization_recipe() + + # turn on weight quantization cache when fusing wgrad accumulation + enable_weight_cache = fuse_wgrad_accumulation + run_num_steps = 1 if not fuse_wgrad_accumulation else 5 + + x, w, bias, gradient = TestDistributedLinearBase._prepare_data( + BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype + ) + + # run the recipe under test + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + run_num_steps=1 if not fuse_wgrad_accumulation else 5, + enable_weight_cache=fuse_wgrad_accumulation, + ) + + # run the reference + setup_environment_for_reference() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + fuse_wgrad_accumulation=fuse_wgrad_accumulation, + run_num_steps=run_num_steps, + enable_weight_cache=enable_weight_cache, + ) + # Clean up env + cleanup_environment() + + # compare results, zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch") + torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch") + torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch") + if bgrad is not None and bgrad_ref is not None: + torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch") + + +def test_linear(): + """Run linear layer tests with various configurations.""" + kwargs_list = [ + {"bias": False}, + ] + + for kwargs in kwargs_list: + if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": + continue + for parallel_mode in ["column", "row"]: + for sequence_parallel in [False, True]: + _test_linear(parallel_mode, sequence_parallel, **kwargs) + + +############################################ +# LayerNormLinear # +############################################ +class TestDistributedLayerNormLinearBase(TestDistributedLinearBase): + + @classmethod + def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None): + # reset gradients + layer.zero_grad() + x.grad = None + + # Forward pass + y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch) + + # Backward pass + y_q.backward(gradient) + + # Collect gradients + dgrad = x.grad + + parameters = layer._parameters + + # bias and weight gradients + bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None + assert "weight" in parameters + wgrad = parameters["weight"].grad + + return y_q, ln_out, dgrad, wgrad, bgrad + + @classmethod + def run_linear_multiple_steps( + cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False + ): + # raise error, no test case for multiple steps for now + raise NotImplementedError("LayerNormLinear does not support test multiple steps for now") + + @classmethod + def run_layernorm_linear( + cls, + x, + w, + bias, + gradient, + parallel_mode=None, + sequence_parallel=False, + tp_group=None, + tp_size=1, + rank=0, + run_num_steps=1, + enable_weight_cache=False, + LayerNormLinearClass=te.LayerNormLinear, + normalization="LayerNorm", + ): + """ + If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with + the reference single GPU run. + """ + # clone inputs and move to current device + # w has shape [N, K], x has shape [M, K], gradient has shape [M, N] + x = x.clone().detach().requires_grad_(True).to("cuda") + w = w.clone().detach().to("cuda") + gradient = gradient.clone().detach().to("cuda") + bias = bias.clone().detach().to("cuda") if bias is not None else None + in_features = x.shape[1] + out_features = w.shape[0] + + # If Model parallel: split inputs for a given rank + x, w, bias, gradient = cls.run_linear_preprocess_parallel( + x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank + ) + + # set data types + params_dtype = x.dtype + + # Create linear layer and copy weights + layer = LayerNormLinearClass( + in_features, + out_features, + bias=bias is not None, + params_dtype=params_dtype, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=tp_group, + tp_size=tp_size, + normalization=normalization, + return_layernorm_output=True, + ) + + layer = layer.to("cuda") + + # Copy weights + # kitchen_linear has different parameter names + with torch.no_grad(): + layer.weight.copy_(w) + if bias is not None: + layer.bias.copy_(bias) + + # Run one step + y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient) + + # If Model parallel: gather output and gradients from all ranks + y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel( + y_q, + dgrad, + wgrad, + bgrad, + parallel_mode, + sequence_parallel, + tp_size, + tp_group, + ) + + return y_q, ln_out, dgrad, wgrad, bgrad + + +@run_distributed_test() +def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs): + """Test the linear layer with specified parallel mode and sequence parallelization. + + Args: + parallel_mode (str): 'column' parallelism. + sequence_parallel (bool): Enable sequence parallelism if True. + kwargs (dict): Additional arguments for the linear layer. + """ + params_dtype = torch.bfloat16 + use_bias = kwargs.get("bias", True) + seed = torch.initial_seed() + recipe = quantization_recipe() + + # run multiple steps currently not supported for LayerNormLinear + run_num_steps = 1 + + x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data( + BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype + ) + + # run the recipe under test + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + run_num_steps=run_num_steps, + enable_weight_cache=False, + ) + + # run the reference + setup_environment_for_reference() + with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = ( + TestDistributedLayerNormLinearBase.run_layernorm_linear( + x, + w, + bias, + gradient, + parallel_mode=parallel_mode, + sequence_parallel=sequence_parallel, + tp_group=NCCL_WORLD, + tp_size=WORLD_SIZE, + rank=WORLD_RANK, + run_num_steps=run_num_steps, + enable_weight_cache=False, + ) + ) + # Clean up env + cleanup_environment() + + # compare results, zero tolerance + if WORLD_RANK == 0: + torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch") + torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch") + torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch") + torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch") + if bgrad is not None and bgrad_ref is not None: + torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch") + + +def test_layernorm_linear(): + kwargs_list = [ + {"bias": False}, + ] + + for kwargs in kwargs_list: + for parallel_mode in ["column"]: + for sequence_parallel in [False, True]: + _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 8ca1fcc1c..11fe4333b 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -27,6 +27,7 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex @@ -34,17 +35,20 @@ # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, make_recipe +from utils import dtype_tols, make_recipe, quantization_tols # Check what quantization schemes are supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_mxfp8_available() quantization_list: list[Optional[str]] = [None] if fp8_available: quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: quantization_list.append("mxfp8") +if nvfp4_available: + quantization_list.append("nvfp4") @functools.cache @@ -115,6 +119,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -437,7 +449,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -609,7 +621,7 @@ def _test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1ff5aff99..d09c530cb 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -31,6 +31,7 @@ fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( FP8GlobalStateManager.is_fp8_block_scaling_available() ) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) @@ -51,7 +52,9 @@ def _run_test(quantization): all_boolean = [True, False] -@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"]) +@pytest.mark.parametrize( + "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"] +) def test_distributed(quantization): if quantization == "fp8" and not fp8_available: pytest.skip(reason_for_no_fp8) @@ -61,4 +64,6 @@ def test_distributed(quantization): pytest.skip(reason_for_no_mxfp8) if quantization == "fp8_block_scaling" and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) _run_test(quantization) diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py new file mode 100644 index 000000000..890a24804 --- /dev/null +++ b/tests/pytorch/distributed/test_numerics_exact.py @@ -0,0 +1,70 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import subprocess +from pathlib import Path + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager + +""" + Distributed numerics tests + + This numerical test aims for zero tolerance test for absolute confidence in numerics. + In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise + result with the native silicon. For distrbuted test cases, we can do the same by thing + by comparing BF16 AG results with the low precision AG results at layer level. +""" + + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(quantization, batch_size, hidden_size, out_size): + test_path = TEST_ROOT / "run_numerics_exact.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + test_cmd += ["--quantization", quantization] + test_cmd += ["--batch-size", str(batch_size)] + test_cmd += ["--hidden-size", str(hidden_size)] + test_cmd += ["--out-size", str(out_size)] + + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 + + +all_boolean = [True, False] + + +@pytest.mark.parametrize("quantization", ["nvfp4"]) +@pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (64, 128, 128), + (128, 128, 128), + (128, 256, 256), + (512, 1024, 768), + (512, 256, 1024), + (2048, 2048, 2048), + ], +) +def test_distributed(quantization, batch_size, hidden_size, out_size): + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + _run_test(quantization, batch_size, hidden_size, out_size) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py new file mode 100644 index 000000000..a9e73aaf9 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -0,0 +1,243 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def check_nvfp4_gemm_versus_reference( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, + accumulate: bool, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, +): + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input tensors + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + x = torch.randn(x_shape, dtype=x_dtype, device=device) + w = torch.randn(w_shape, dtype=w_dtype, device=device) + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Native TE NVFP4 quantization + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + # Quantize x and w + x_nvfp4_native = x_quantizer.make_empty( + x_shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native) + w_nvfp4_native = w_quantizer.make_empty( + w_shape, dtype=w_dtype, device=device, requires_grad=False + ) + w_nvfp4_native = w_quantizer.update_quantized(w, w_nvfp4_native) + + # Extract quantized data from native NVFP4Tensors + qx_data = ( + x_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if x_columnwise + else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + qw_data = ( + w_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if w_columnwise + else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + sx_native = ( + x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv + ) + sw_native = ( + w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv + ) + + # Trim quantized data to match the actual tensor dimensions (remove padding) + qx_data = qx_data[:M, :] + qw_data = qw_data[:N, :] + + # NVFP4 uses 16-element blocks, trim scales to remove padding + block_length = 16 # NVFP4 uses 16-element blocks + expected_sx_cols = expected_sw_cols = K // block_length + # Trim the scales to remove padding + sx_trimmed = sx_native[:M, :expected_sx_cols] + sw_trimmed = sw_native[:N, :expected_sw_cols] + + # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn + # for the reference GEMM to work correctly + sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) + sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) + + # Create reference quantizer for reference GEMM + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=True, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + + # Create reference quantized tensors needed by reference GEMM + x_nvfp4_ref = ref_quantizer.quantize(x) + w_nvfp4_ref = ref_quantizer.quantize(w) + + # Reference GEMM using quantizer's qgemm method + y_ref = ref_quantizer.qgemm( + qx=qx_data, + qw=qw_data, + m_params=None, # MMParams not used in reference + out_dtype=out_dtype, + sx=sx_trimmed, + sw=sw_trimmed, + bias=None, # No bias for this test + out=out.clone() if accumulate else None, + accumulate=accumulate, + gemm_type=None, # GEMMType not used in reference + qresult_x=x_nvfp4_ref, + qresult_w=w_nvfp4_ref, + ) + + # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) + # Allocate cuBLAS workspace + workspace = torch.empty(4, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + bias = None + bias_dtype = TE_DType[torch.bfloat16] + use_gelu = False + gelu_input = None + use_grad = False + use_split_accumulator = False + + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y_native are not the same tensor + assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native) + + # Compare results with some tolerance + torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (False, False), # Only rowwise x rowwise is supported by reference GEMM + # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization + # Columnwise layouts are not supported by the reference implementation + ], + ids=["rowxrow"], +) +def test_nvfp4_gemm_versus_reference( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + accumulate: bool, + is_x_columnwise: bool, + is_w_columnwise: bool, +): + check_nvfp4_gemm_versus_reference( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, + accumulate=accumulate, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py new file mode 100644 index 000000000..ae9975839 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -0,0 +1,559 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import os +import pytest +import torch +import transformer_engine as te +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.distributed import fp8_autocast +from transformer_engine.common import recipe + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +class GetRecipes: + @staticmethod + def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True) + return nvfp4_recipe + + @staticmethod + def nvfp4_2d_quantization_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False) + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + @staticmethod + def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + return GetRecipes.nvfp4_rht_and_2d_quantization() + elif with_rht: + return GetRecipes.nvfp4_rht_only() + elif with_2d_quantization: + return GetRecipes.nvfp4_2d_quantization_only() + else: + return GetRecipes.nvfp4_vanilla() + + +def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + os.environ["QAT_PARAMS"] = "9003" + elif with_rht: + os.environ["QAT_PARAMS"] = "960109" + elif with_2d_quantization: + os.environ["QAT_PARAMS"] = "9002" + else: + os.environ["QAT_PARAMS"] = "6010" + + +def cleanup_environment(): + if "QAT_PARAMS" in os.environ: + del os.environ["QAT_PARAMS"] + + +def reset_rng_states(): + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def check_nvfp4_module_versus_reference( + module_class, + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 module against reference implementation. + + Args: + module_class: te.Linear or te.LayerNormLinear + in_features: Input feature dimension + out_features: Output feature dimension + bias: Whether to use bias + x_dtype: Input tensor dtype + num_steps: Number of forward/backward steps to test + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + cleanup_environment() + reset_rng_states() + + # Create native module + print("\nCreate native module") + if module_class == te.pytorch.Linear: + native_module = te.pytorch.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.pytorch.LayerNormLinear: + native_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + else: + raise ValueError(f"Unsupported module class: {module_class}") + + # Create reference module with same weights + setup_environment_for_reference(with_rht, with_2d_quantization) + reset_rng_states() + + # Create reference module + print("Create reference module") + if module_class == te.pytorch.Linear: + ref_module = te.pytorch.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.pytorch.LayerNormLinear: + ref_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + + # Sync weights between native and reference modules + with torch.no_grad(): + # Copy main weight and bias parameters + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + + # Copy layer norm parameters if they exist + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + + # Training loop comparison + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + cleanup_environment() + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + # enable weight cache by giving is_first_microbatch + y_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + setup_environment_for_reference(with_rht, with_2d_quantization) + with fp8_autocast( + enabled=True, fp8_recipe=nvfp4_recipe + ): # Exact recipe does not play a role here + y_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + # Store results + native_outputs.append( + { + "output": y_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results across all steps + for step in range(num_steps): + native_out = native_outputs[step] + ref_out = ref_outputs[step] + + # Compare outputs + torch.testing.assert_close( + native_out["output"], + ref_out["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + + # Compare input gradients + torch.testing.assert_close( + native_out["input_grad"], + ref_out["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:" + f" {ref_out['input_grad']}" + ), + ) + + # Compare weight gradients + torch.testing.assert_close( + native_out["weight_grad"], + ref_out["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Weight gradient mismatch at step {step}. Native: {native_out['weight_grad']}," + f" Ref: {ref_out['weight_grad']}" + ), + ) + + # Compare bias gradients + if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None: + torch.testing.assert_close( + native_out["bias_grad"], + ref_out["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + # Clean up + cleanup_environment() + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + (512, 512), + (768, 3072), + (1024, 4096), + ], +) +# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + """Test NVFP4 Linear module against reference implementation.""" + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_module_versus_reference( + module_class=te.pytorch.Linear, + in_features=in_features, + out_features=out_features, + bias=bias, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) + + +def check_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 LayerNormLinear module against reference implementation, + including ln_out. + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + cleanup_environment() + reset_rng_states() + + # Native module + native_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Reference module + setup_environment_for_reference(with_rht, with_2d_quantization) + reset_rng_states() + ref_module = te.pytorch.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Sync weights and LN params + with torch.no_grad(): + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + if ( + native_module.layer_norm_weight is not None + and ref_module.layer_norm_weight is not None + ): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None: + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + cleanup_environment() + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + setup_environment_for_reference(with_rht, with_2d_quantization) + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + y_ref, ln_out_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + native_outputs.append( + { + "output": y_native.detach().clone(), + "ln_out": ln_out_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "ln_out": ln_out_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results + for step in range(num_steps): + n = native_outputs[step] + r = ref_outputs[step] + torch.testing.assert_close( + n["output"], + r["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + torch.testing.assert_close( + n["ln_out"], + r["ln_out"], + atol=1e-6, + rtol=1e-6, + msg=f"LN output mismatch at step {step}", + ) + torch.testing.assert_close( + n["input_grad"], + r["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Input gradient mismatch at step {step}", + ) + torch.testing.assert_close( + n["weight_grad"], + r["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Weight gradient mismatch at step {step}", + ) + if bias and n["bias_grad"] is not None and r["bias_grad"] is not None: + torch.testing.assert_close( + n["bias_grad"], + r["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + cleanup_environment() + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + ], +) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1], ids=["single_step"]) +@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_layernorm_linear_versus_reference( + in_features=in_features, + out_features=out_features, + bias=bias, + normalization=normalization, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py new file mode 100644 index 000000000..dc3c4a4e9 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -0,0 +1,495 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + NVFP4Quantizer, +) +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype + + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + # Reference quantization + quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=quant_tile_shape, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + if x_nvfp4_ref.data_t is not None + else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # # largest tile + (8192, 8192), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"]) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] +) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_transpose=return_transpose, + swizzled_scale=swizzled_scale, + use_cpp_allocator=use_cpp_allocator, + with_2d_quantization=with_2d_quantization, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + extrema_high: bool, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (16, 128), + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_boundary_values( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + """ + Stress rounding/threshold behavior by placing values just below/above + many potential bin edges within each 16-element microblock. + Validates native vs reference byte-for-byte and scale parity. + """ + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 123 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Construct a single row with paired boundary values: v-eps, v+eps + # spanning a wide dynamic range to exercise clipping and multiple bins. + # Ensure even N and N is multiple of 16 for microblocks, which holds for 128. + base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device) + eps = torch.full_like(base, 1e-3) + # Avoid zero eps for very small magnitudes + eps = torch.maximum(eps, 1e-4 * torch.ones_like(base)) + lower = base - eps + upper = base + eps + row = torch.empty(N, dtype=torch.float32, device=device) + row[0::2] = lower + row[1::2] = upper + x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim any padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 17 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Start from a contiguous tensor, then make a non-contiguous view by transpose + x_base = torch.randn((M, N), dtype=x_dtype, device=device) + x_nc = x_base.t() # shape (N, M), non-contiguous + assert not x_nc.is_contiguous() + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x_nc) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + # Quantized must match + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py new file mode 100644 index 000000000..bb542456e --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -0,0 +1,255 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +# NOTE: This file is dependent on the success of test_nvfp4_quantize_exact.py. +# Separate to make sure all the functionalities are working as expected. +# Otherwise reference implementation will get messy. + +# Due to the structure of NVFP4Quantizer, we need to test the RHT functionality +# together with the quantization functionality. + +from typing import Tuple +import math + +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.tensor.nvfp4_tensor import ( + NVFP4Quantizer, +) +from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype + +import pytest +import torch + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + contiguous: bool, + return_transpose: bool, + use_cpp_allocator: bool, + swizzled_scale: bool = False, + hadamard_dimension: int = 16, + with_rht: bool = True, + with_post_rht_amax: bool = True, + with_random_sign_mask: bool = True, +) -> None: + assert with_rht and with_post_rht_amax, "RHT and post-RHT amax reduction must be enabled." + + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + x = x.transpose(0, 1) if not contiguous else x + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + amax_rowwise = x_nvfp4_sut._amax_rowwise + amax_colwise = x_nvfp4_sut._amax_columnwise + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + # Reference quantization using NVFP4QuantizerRef with built-in RHT + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=with_rht, + with_random_sign_mask=with_random_sign_mask, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + ref_amax_rowwise = x_nvfp4_ref.global_amax_row + + if return_transpose: + assert x_nvfp4_ref.data_t is not None + assert x_nvfp4_ref.scale_t is not None + qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8) + # Compute transpose amax using the same reference quantizer + x_t_for_amax = ( + ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous() + ) + ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1) + else: + qx_t_ref = None + sx_t_ref = None + ref_amax_colwise_t = None + + torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # Real shapes, + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_rht_with_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=True, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +): + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=False, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py new file mode 100755 index 000000000..46077eb20 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -0,0 +1,238 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() + +seed = 12345 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +_FP4_LUT = torch.tensor( + [ + 0.0, # 0: 0000 - zero + 0.5, # 1: 0001 - smallest positive normal + 1.0, # 2: 0010 + 1.5, # 3: 0011 + 2.0, # 4: 0100 + 3.0, # 5: 0101 + 4.0, # 6: 0110 + 6.0, # 7: 0111 - largest positive normal + -0.0, # 8: 1000 - negative zero + -0.5, # 9: 1001 - smallest negative normal + -1.0, # 10: 1010 + -1.5, # 11: 1011 + -2.0, # 12: 1100 + -3.0, # 13: 1101 + -4.0, # 14: 1110 + -6.0, # 15: 1111 - largest negative normal + ], + dtype=torch.float32, +) + + +def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor: + # Convert FP4 indices to their corresponding floating point values + # Each index (0-15) represents a 4-bit FP4 value in E2M1 format + # Values based on the FP4 E2M1 specification + fp4_lut = _FP4_LUT.to(fp4.device) + return fp4_lut[fp4.to(torch.long)] + + +def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor: + sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32) + dqx = fp4_to_fp32(unpack_fp4(qx)) + sf = sf[: dqx.shape[0], : dqx.shape[1]] + dequant = dqx * sf * (amax / (6.0 * 448)) + return dequant + + +def RHT(x: torch.Tensor) -> torch.Tensor: + def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [ + 1.0, + 1.0, + 1.0, + -1.0, + 1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + -1.0, + 1.0, + -1.0, + 1.0, + -1.0, + -1.0, + ], + dtype=torch.float32, + ) + + def _build_hadamard_matrix( + size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True + ) -> torch.Tensor: + """Construct a Hadamard matrix of given power-of-two size with entries +-1. + + Uses Sylvester construction to avoid SciPy dependency. + """ + assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + h = torch.ones((1, 1), device=device, dtype=torch.float32) + while h.shape[0] < size: + h = torch.cat( + [ + torch.cat([h, h], dim=1), + torch.cat([h, -h], dim=1), + ], + dim=0, + ) + if with_random_sign_mask: + sign_mat = get_wgrad_sign_vector().to(device) * torch.eye( + size, device=device, dtype=torch.float32 + ) + h = sign_mat @ h + return h.to(dtype) + + rht_dim = 16 + # Build H and scale + H = _build_hadamard_matrix(rht_dim, x.device, x.dtype) + scale = 1.0 / float(rht_dim) ** 0.5 + + # Perform blockwise transform along the last dimension + original_shape = x.shape + x_mat = x.contiguous().view(-1, rht_dim) + # Random sign matrix is identity in this reference (no sign flipping) + transform = H * scale + out = x_mat @ transform + return out.view(original_shape) + + +def quantize_fp4( + x: torch.Tensor, use_stochastic_rounding: bool, use_2D: bool, use_RHT: bool +) -> torch.Tensor: + nvfp4_quantizer = NVFP4Quantizer( + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=use_RHT, + with_post_rht_amax=True, + stochastic_rounding=use_stochastic_rounding, + with_2d_quantization=use_2D, + ) + + x_nvfp4_sut = nvfp4_quantizer(x) + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + assert x_nvfp4_sut._columnwise_data is not None + qx_t: torch.Tensor = x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._columnwise_scale_inv is not None + sx_t: torch.Tensor = x_nvfp4_sut._columnwise_scale_inv + + return qx, sx, qx_t, sx_t + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool +) -> None: + device = "cuda" + torch.manual_seed(seed) + n_iters = 50 + + x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1 + y = x.t().contiguous() + if use_RHT: + y = RHT(y) + amax = torch.max(torch.abs(x)).float() + q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4( + x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT + ) + dq_rn = dequantize_fp4(q_rn, s_rn, amax) + dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax) + error_rn = (dq_rn - x).float() + me_rn = torch.sqrt((error_rn * error_rn).mean()) + error_t_rn = (dq_t_rn - y).float() + me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean()) + sr_result = torch.zeros_like(x).float() + sr_t_result = torch.zeros_like(x).float().t().contiguous() + for i in range(n_iters): + q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4( + x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT + ) + + dq_sr = dequantize_fp4(q_sr, s_sr, amax) + dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax) + + sr_result += dq_sr.float() + sr_t_result += dq_t_sr.float() + + # sr_result_tmp = sr_result / (i + 1) + # error_sr = (sr_result_tmp - x).float() + # me_sr = torch.sqrt((error_sr * error_sr).mean()) + # sr_t_result_tmp = sr_t_result / (i + 1) + # error_t_sr = (sr_t_result_tmp - y).float() + # me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean()) + # print(f"Iteration {i}: RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") + # print(f"Iteration {i}: RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") + + # Get the mean result of the stochastic rounding + # It should be more accurate than the RN result + sr_result /= n_iters + error_sr = (sr_result - x).float() + me_sr = torch.sqrt((error_sr * error_sr).mean()) + sr_t_result /= n_iters + error_t_sr = (sr_t_result - y).float() + me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean()) + + print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}") + print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}") + assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest." + assert me_t_sr < me_t_rn, "Stochastic rounding failed - error larger than the round to nearest." + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (8192, 8192), + (8192, 8256), # to test the nonfused RHT path + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_2D", [False, True], ids=str) +@pytest.mark.parametrize("use_RHT", [False, True], ids=str) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + use_2D: bool, + use_RHT: bool, + M: int, + N: int, +) -> None: + if x_dtype == torch.float32 and use_RHT: + pytest.skip("RHT is only supported with bfloat16") + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + use_2D=use_2D, + use_RHT=use_RHT, + M=M, + N=N, + ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 90e624c94..be7a65deb 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -32,12 +32,59 @@ reset_rng_states() model_configs = { - "small": ModelConfig(32, 2, 2, 32), + "small": ModelConfig(2, 32, 2, 32), } + +def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + # if using RHT, we can only support bf16 + # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + supported_input_dtypes.append(torch.bfloat16) + # if not using RHT, we can add fp32 as well + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -278,7 +325,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) def test_make_graphed_callables( *, module: str, @@ -295,8 +342,18 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op": - pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") + if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op": + pytest.skip( + f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" + ) + if fp8 and fp8_recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {fp8_recipe.__class__.__name__}" + ) + if fp8_params: + pytest.skip("NVFP4 params not supported") # Run model with different CUDA graph settings. model_config = model_configs[model_config] @@ -334,17 +391,19 @@ def test_make_graphed_callables( "module", _test_make_graphed_callables_with_fp8_weight_caching_modules, ) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, + dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, - dtype=torch.float32, + dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, fp8_weight_caching=True, diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index a0d6f1fd9..82bd61a01 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -10,7 +10,6 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex -import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8CurrentScaling from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype @@ -273,6 +272,14 @@ def run_linear_multiple_steps( if bgrad_list is not None and bgrad is not None: bgrad_list.append(bgrad.detach().clone()) + # Stack the results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + @classmethod def run_linear( cls, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index bb07e87d9..440986661 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -35,15 +35,17 @@ Float8Quantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Import utility functions -from utils import dtype_tols, make_recipe, reset_rng_states +from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states -# Check if FP8 is supported +# Check for supported quantization schemes fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] @@ -59,6 +61,8 @@ _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: _quantization_list.append("mxfp8") +if nvfp4_available: + _quantization_list.append("nvfp4") def maybe_skip_quantization( @@ -66,6 +70,7 @@ def maybe_skip_quantization( *, dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, ) -> None: """Skip test case if a quantization scheme is not supported""" @@ -73,12 +78,17 @@ def maybe_skip_quantization( if quantization is None: return - # Check if quantization scheme is supported + # Check if quantization scheme is supported on device + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + # Check dims if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) @@ -88,10 +98,14 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + elif quantization == "nvfp4": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") - # Check if device is supported - if device is not None and torch.device(device).type != "cuda": - pytest.skip("Quantization is only supported on CUDA devices") + # Check dtype + if dtype is not None: + if quantization == "nvfp4" and dtype != torch.bfloat16: + pytest.skip("NVFP4 quantization is only supported with BF16 data") @torch.no_grad() @@ -141,6 +155,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -395,12 +417,12 @@ def test_fp8_scale_update( torch.testing.assert_close( y, torch.full_like(y, y_val_ref), - **dtype_tols(tex.DType.kFloat8E4M3), + **quantization_tols("fp8_delayed_scaling"), ) torch.testing.assert_close( x.grad, torch.full_like(x.grad, dx_val_ref), - **dtype_tols(tex.DType.kFloat8E5M2), + **quantization_tols("fp8_delayed_scaling"), ) # Check that scaling factors match expected @@ -434,7 +456,8 @@ def test_dtype_cast( # Skip invalid configurations in_shape = (size, size) with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype) + maybe_skip_quantization(quantization, dtype=final_dtype) # Random data dtype = torch.float32 @@ -502,7 +525,8 @@ def test_pyt_autocast( # Skip invalid configurations in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype) + maybe_skip_quantization(quantization, dtype=autocast_dtype) # Construct operation recipe = make_recipe(quantization) @@ -558,7 +582,7 @@ def test_identity( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -624,7 +648,7 @@ def test_reshape( # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) with_quantization = quantization is not None # Random data @@ -690,7 +714,7 @@ def test_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -752,7 +776,7 @@ def test_quantize( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8": maybe_skip_quantization(quantization, dims=in_shape) @@ -819,7 +843,7 @@ def _test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) quantization_needed = any( ( @@ -899,7 +923,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute or quantized_output or quantized_grad_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1010,7 +1034,7 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantization is None and (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not specified") @@ -1077,7 +1101,7 @@ def test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1114,7 +1138,7 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1175,7 +1199,7 @@ def test_layer_norm( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1284,7 +1308,7 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1337,7 +1361,7 @@ def test_rmsnorm( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1417,7 +1441,7 @@ def test_add_extra_input( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x1_ref, x1_test = make_reference_and_test_tensors( @@ -1456,8 +1480,11 @@ def test_add_extra_input( # Check results tols = dtype_tols(dtype) - if with_quantization: - tols = dtype_tols(x1_test._fp8_dtype) + if in_place: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): + tols = dtype_tols(x1_test._fp8_dtype) + elif quantization == "nvfp4": + tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") @@ -1486,7 +1513,7 @@ def test_make_extra_output( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1559,7 +1586,7 @@ def test_activation( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if cache_quantized_input: maybe_skip_quantization("fp8_current_scaling", device=device) @@ -1633,8 +1660,10 @@ def test_activation( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute or cache_quantized_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + if quantized_compute: + tols = quantization_tols(quantization) + elif cache_quantized_input: + tols = quantization_tols("fp8_current_scaling") # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1665,7 +1694,7 @@ def test_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1699,7 +1728,7 @@ def test_swiglu( # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1767,7 +1796,7 @@ def test_dropout( # Skip invalid configurations quantized_input = quantization is not None - maybe_skip_quantization(quantization, dims=shape, device=device) + maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype) # Random data # Note: Shift values to make sure inputs are non-zero @@ -1858,7 +1887,7 @@ def test_forward_linear_bias_activation( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( @@ -1929,7 +1958,7 @@ def test_forward_linear_bias_activation( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1965,7 +1994,7 @@ def test_forward_linear_bias_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2040,7 +2069,7 @@ def test_forward_linear_bias_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2078,7 +2107,7 @@ def test_forward_linear_scale_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2146,7 +2175,7 @@ def test_forward_linear_scale_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2179,7 +2208,7 @@ def test_backward_activation_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0): pytest.skip("Unsupported tensor size for MXFP8") @@ -2241,7 +2270,7 @@ def test_backward_activation_bias( # Expected numerical error tols = dtype_tols(dtype) if with_quantization: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2360,7 +2389,7 @@ def test_backward_linear_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2428,7 +2457,7 @@ def test_backward_linear_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") @@ -2463,7 +2492,7 @@ def test_backward_linear_scale( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2523,7 +2552,7 @@ def test_backward_linear_scale( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2564,7 +2593,7 @@ def test_linear( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) # Construct model @@ -2690,7 +2719,7 @@ def test_layernorm_mlp( ffn_shape = in_shape[:-1] + (ffn_hidden_size,) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=ffn_shape, device=device) quantization_needed = quantized_compute or quantized_weight if quantization is None and quantization_needed: diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 9a51c53e3..004abfd97 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -19,6 +19,7 @@ fp8_model_init, ) from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear from transformer_engine.pytorch.distributed import fp8_autocast @@ -499,3 +500,39 @@ def test_quantizer_update(self, module_class): y = module(x, [batch_size]) else: y = module(x) + + +fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available() + + +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # # largest tile + (8192, 8192), + ], +) +def test_fp4_dequantize(dtype, M, N): + q = NVFP4Quantizer() + a = torch.rand((M, N)).cuda().to(dtype=dtype) + starting_tensor = q(a) + dequantized_tensor = starting_tensor.dequantize() + new_tensor = q(dequantized_tensor) + torch.testing.assert_close( + new_tensor._rowwise_data, + starting_tensor._rowwise_data, + rtol=0, + atol=0, + ) + new_dequantized_tensor = new_tensor.dequantize() + torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 5151aa96e..981c58243 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -87,9 +87,19 @@ def is_fp8_supported(config: ModelConfig): "large": ModelConfig(2, 128, 4, 128, num_layers=1), } + +def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) + fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -379,6 +389,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -407,6 +419,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) @@ -437,6 +451,8 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): @@ -476,6 +492,8 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4(): + pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): @@ -526,6 +544,8 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -568,6 +588,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -629,6 +651,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -683,6 +707,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -734,6 +760,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -764,6 +792,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -798,6 +828,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -832,6 +864,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 9e90f9fda..d77256b7f 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -73,6 +73,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: # Transformer Engine dtypes if isinstance(dtype, tex.DType): + if dtype == tex.DType.kFloat4E2M1: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.25 dtype = { tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, @@ -95,10 +97,25 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: if dtype == torch.float8_e4m3fn: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 if dtype == torch.float8_e5m2: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") +def quantization_tols(name: str) -> dict[str, float]: + """Estimated numerical error for a quantization scheme""" + if name in ( + "fp8", + "fp8_delayed_scaling", + "fp8_current_scaling", + "mxfp8", + "mxfp8_block_scaling", + ): + return dtype_tols(tex.DType.kFloat8E4M3) + if name == "nvfp4": + return dtype_tols(tex.DType.kFloat4E2M1) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def make_recipe(name: Optional[str]) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: @@ -118,6 +135,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() + if name == "nvfp4": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + ) raise ValueError(f"Unsupported quantization scheme ({name})") diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 08e876404..a4915080e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -53,6 +53,28 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) +# NVIDIA MathDX include directory (from Python package install location) +if(NOT DEFINED MATHDX_INCLUDE_DIR) + execute_process( + COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx + OUTPUT_VARIABLE _PIP_SHOW_MATHDX + ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR + RESULT_VARIABLE _PIP_SHOW_MATHDX_RES + OUTPUT_STRIP_TRAILING_WHITESPACE) + if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) + message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") + endif() + string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") + if(NOT _MATHDX_LOC_MATCH) + message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") + endif() + set(MATHDX_LOCATION "${CMAKE_MATCH_1}") + set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") +endif() +if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") + message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") +endif() + # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) @@ -73,6 +95,7 @@ list(APPEND transformer_engine_SOURCES transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu activation/gelu.cu dropout/dropout.cu fused_attn/flash_attn.cu @@ -85,6 +108,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp fused_attn/utils.cu + gemm/config.cpp gemm/cublaslt_gemm.cu gemm/cutlass_grouped_gemm.cu normalization/common.cpp @@ -113,6 +137,9 @@ list(APPEND transformer_engine_SOURCES recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu + recipe/nvfp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu @@ -144,7 +171,8 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 8b7f92aff..666f57188 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -39,6 +39,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { return CUDA_R_8F_E4M3; case DType::kFloat8E5M2: return CUDA_R_8F_E5M2; +#if CUDA_VERSION >= 12080 + case DType::kFloat4E2M1: + return CUDA_R_4F_E2M1; +#endif default: NVTE_ERROR("Invalid type"); } @@ -160,7 +164,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits) { + const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { @@ -169,6 +175,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, }(); // rank is the number of dimensions of the array constexpr uint32_t rank = 2; + + // Dimension for the packed data types must reflect the number of individual U# values. uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next @@ -207,7 +215,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, // Swizzling can be used to avoid shared memory bank conflicts. - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + swizzle, // L2 Promotion can be used to widen the effect of a cache-policy to a wider // set of L2 cache lines. diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index e2a3c52aa..bddd9bf19 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -48,8 +48,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } +inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + +inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } +inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -108,6 +114,7 @@ struct Tensor { SimpleTensor data; SimpleTensor columnwise_data; SimpleTensor amax; + SimpleTensor columnwise_amax; SimpleTensor scale; SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; @@ -119,6 +126,7 @@ struct Tensor { : data(), columnwise_data(), amax(nullptr, {1}, DType::kFloat32), + columnwise_amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), @@ -129,6 +137,7 @@ struct Tensor { data.clear(); columnwise_data.clear(); amax.clear(); + columnwise_amax.clear(); scale.clear(); scale_inv.clear(); columnwise_scale_inv.clear(); @@ -174,6 +183,7 @@ struct Tensor { * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). */ switch (scaling_mode) { + case NVTE_NVFP4_1D_SCALING: case NVTE_DELAYED_TENSOR_SCALING: if (!has_data() && has_columnwise_data()) { std::vector ret; @@ -189,7 +199,6 @@ struct Tensor { } break; case NVTE_MXFP8_1D_SCALING: - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; } else { @@ -261,12 +270,18 @@ struct QuantizationConfig { NVTETensor noop_tensor = nullptr; Float8BlockScaleTensorFormat float8_block_scale_tensor_format = Float8BlockScaleTensorFormat::GEMM_READY; + NVTETensor rng_state = nullptr; + bool nvfp4_2d_quantization = false; + bool stochastic_rounding = false; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float), // amax_epsilon - sizeof(NVTETensor), // noop_tensor - sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor), // noop_tensor + sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format + sizeof(NVTETensor), // rng_seed and offset + sizeof(bool), // nvfp4_2d_quantization + sizeof(bool) // stochastic_rounding }; }; @@ -298,6 +313,8 @@ using fp8e8m0 = __nv_fp8_e8m0; #endif #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif using e8m0_t = uint8_t; @@ -334,17 +351,20 @@ struct TypeExtrema; template <> struct TypeExtrema { static constexpr float max = 6.0f; + static constexpr float max_inverse = 1.0 / max; }; #endif template <> struct TypeExtrema { static constexpr float max = 448.0f; + static constexpr float max_inverse = 1.0 / max; }; template <> struct TypeExtrema { static constexpr float max = 57344.0f; + static constexpr float max_inverse = 1.0 / max; }; template <> @@ -558,6 +578,18 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +// Add a pack_size argument to select the packed type for FP4 +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat4E2M1: { \ + using type = __nv_fp4x2_storage_t; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -717,10 +749,11 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. -void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, - const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, - const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits); +void create_2D_tensor_map( + CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, + const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, + const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); bool is_supported_by_CC_100(); diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp new file mode 100644 index 000000000..cf211beaf --- /dev/null +++ b/transformer_engine/common/gemm/config.cpp @@ -0,0 +1,116 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "./config.h" + +#include +#include + +#include + +#include "../util/logging.h" + +NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; } + +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(buf, &config_.bias_tensor, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(buf, &config_.dbias_tensor, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(buf, &config_.with_gelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(buf, &config_.use_split_accumulator, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(&config_.bias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(&config_.dbias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(&config_.with_gelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(&config_.use_split_accumulator, buf, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_matmul_config(NVTEMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h new file mode 100644 index 000000000..54ccf06a5 --- /dev/null +++ b/transformer_engine/common/gemm/config.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_ +#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_ + +#include + +namespace transformer_engine { + +struct MatmulConfig { + NVTETensor bias_tensor = nullptr; + NVTETensor dbias_tensor = nullptr; + bool with_gelu_epilogue = false; + bool with_dgelu_epilogue = false; + NVTETensor epilogue_aux_tensor = nullptr; + bool use_split_accumulator = false; + int sm_count = 0; + + static constexpr size_t attr_sizes[] = { + sizeof(NVTETensor), // bias_tensor + sizeof(NVTETensor), // dbias_tensor + sizeof(bool), // with_gelu_epilogue + sizeof(bool), // with_dgelu_epilogue + sizeof(NVTETensor), // epilogue_aux_tensor + sizeof(bool), // use_split_accumulator + sizeof(int) // sm_count + }; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index f287072bc..ab80fe769 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -9,20 +9,55 @@ #include #include #include +#include #include +#include #include #include +#include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/handle_manager.h" #include "../util/logging.h" #include "../util/multi_stream.h" -#include "common/util/cuda_runtime.h" -#include "cutlass_grouped_gemm.cuh" +#include "./config.h" +#include "./cutlass_grouped_gemm.cuh" namespace { +/* Use CUDA const memory to store scalar 1 and 0 for cublas usage +*/ +__device__ __constant__ float one_device; +__device__ __constant__ float zero_device; + +inline float *GetScalarOne() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float one = 1.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), one_device)); + return dev_ptr; +} + +inline float *GetScalarZero() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float zero = 0.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), zero_device)); + return dev_ptr; +} + +__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; } + uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; @@ -82,6 +117,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool is_A_transposed = transA == CUBLAS_OP_T; bool is_B_transposed = transB == CUBLAS_OP_T; + // Set conditions for MXFP8 and NVFP4 gemm execution. + const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); + const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { // Unscaled or FP8 tensor scaling @@ -102,10 +141,26 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } - } else if (is_mxfp_scaling(A.scaling_mode)) { - // MXFP8 + } else if (nvfp4) { + // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. + + if (is_A_transposed) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode), + "Input A has unsupported combination of recipe and layout"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); + } + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; // NVFP4 gemm is only supported in TN layout. + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + } else if (mxfp8) { + // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { @@ -161,10 +216,20 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } } - } else if (is_mxfp_scaling(B.scaling_mode)) { - // MXFP8 - // Note: Row-wise and column-wise data are scaled along different - // dimensions (with matrix interpreted in row-major order). + } else if (nvfp4) { + if (is_B_transposed) { + NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode), + "Input B has unsupported combination of recipe and layout"); + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; // NVFP4 gemm is only supported in TN layout. + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (mxfp8) { if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { @@ -221,7 +286,7 @@ using cublasHandleManager = detail::HandleManageramax.dptr != nullptr || inputB->amax.dptr != nullptr)) { + // Reserve some workspace for alpha scale + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + uint8_t *workspace_ptr = reinterpret_cast(workspace); + float *new_alpha_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + + // Update alpha scale on device + // Note: Compute NVFP4 tensor scales based on amaxes and then + // divide from alpha scale. This way we only need to apply NVFP4 + // tensor scales in matmul output, instead of in matmul inputs. + float old_alpha = *reinterpret_cast(alpha); // Assumed to be on CPU + TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector{1}, DType::kFloat32); + nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb, + old_alpha, new_alpha_tensor.data(), stream); + alpha = new_alpha_ptr; + + // Make sure beta scale is on device + float old_beta = *reinterpret_cast(beta); // Assumed to be on CPU + if (old_beta == 0) { + beta = GetScalarZero(); // Device constant memory + } else if (old_beta == 1) { + beta = GetScalarOne(); // Device constant memory + } else { + // Move beta to workspace + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + float *new_beta_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta); + NVTE_CHECK_CUDA(cudaGetLastError()); + beta = new_beta_ptr; + } + } const cudaDataType_t A_type = get_cuda_dtype(param.Atype); const cudaDataType_t B_type = get_cuda_dtype(param.Btype); @@ -270,16 +378,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "FP8 input to GEMM requires inverse of scale!"); NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); // check consistency of arguments: // if fp8 is desired, context cannot be null // fp8 + gelu fusion + fp8 aux is unavailable right now. - if (use_fp8 && gelu) { + if ((use_fp8 || use_fp4) && gelu) { NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), "fp8 Aux output for gemm + gelu fusion not supported!"); } - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!"); + if (is_fp4_dtype(outputD->data.dtype)) { + NVTE_ERROR("FP4 GEMM output is not supported!"); + } + if (use_fp4 && (D_type == CUDA_R_16F)) { + NVTE_ERROR("FP4 GEMM does not support FP16 output!"); } cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -319,12 +434,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &math_sm_count, sizeof(math_sm_count))); } - // set fp8 attributes -- input and output types should already be set to fp8 as appropriate - // Note: gelu fusion isn't available right now, and we don't need + // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4 + // as appropriate. Note: gelu fusion isn't available right now, and we don't need // amax(D) either (next op is high precision). - if (use_fp8) { - // Split accumulator. - const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + if (use_fp8 || use_fp4) { + // Fast accumulation is only supported for FP8. + const int8_t fastAccuMode = (use_split_accumulator) ? 0 : use_fp8; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); @@ -333,7 +450,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, cublasLtMatmulMatrixScale_t scaling_mode_a; cublasLtMatmulMatrixScale_t scaling_mode_b; #endif // CUBLAS_VERSION >= 120800 - if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { + if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) { void *A_scale_inverse = param.A_scale_inv; void *B_scale_inverse = param.B_scale_inv; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -346,7 +463,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; #endif // CUBLAS_VERSION >= 120800 - } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { + } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cublas_version() >= 120800, "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); @@ -371,6 +488,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #else NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= 120800 + } else if (use_fp4) { // NVFP4 GEMM +#if CUBLAS_VERSION >= 120800 + NVTE_CHECK(cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE + cublasDataType_t scale_type = CUDA_R_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + + // Set pointer mode: alpha and beta are both device pointers + // https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); + + fp8e4m3 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e4m3 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; +#else + NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif // CUBLAS_VERSION >= 120800 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && @@ -503,14 +648,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif -#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ - CUBLAS_VERSION < 130000 +#else NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); @@ -565,16 +707,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ - param.A, /* A */ - Adesc, param.B, /* B */ - Bdesc, static_cast(&beta), /* beta */ - C, /* C */ - Cdesc, D, /* D */ - Ddesc, &heuristicResult.algo, /* algo */ - workspace, /* workspace */ - workspaceSize, stream)); /* stream */ + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */ + param.A, /* A */ + Adesc, param.B, /* B */ + Bdesc, beta, /* beta */ + C, /* C */ + Cdesc, D, /* D */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. @@ -600,35 +741,117 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, - nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); +} + +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm_v2); + using namespace transformer_engine; + + // Data tensors + const Tensor *A_tensor = convertNVTETensorCheck(A); + const Tensor *B_tensor = convertNVTETensorCheck(B); + const Tensor *C_tensor = convertNVTETensorCheck(C); + Tensor *D_tensor = convertNVTETensorCheck(D); + NVTE_CHECK(C_tensor == D_tensor, + "Currently nvte_cublas_gemm_v2 does not support different C and D tensors."); + + // Workspace + void *workspace_ptr = nullptr; + size_t workspace_size = 0; + Tensor *workspace_tensor = convertNVTETensor(workspace); + if (workspace_tensor != nullptr) { + workspace_ptr = workspace_tensor->data.dptr; + workspace_size = + get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype); + } + + // Additional config + MatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + + // Configure GEMM epilogue + const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue); + if (with_grad_epilogue) { + NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue, + "Invalid epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ")."); + } + Tensor dummy_tensor; + Tensor *epilogue_bias_tensor = &dummy_tensor; + if (!with_grad_epilogue && config_.bias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor); + } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor); + } + Tensor *epilogue_aux_tensor = &dummy_tensor; + if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) { + NVTE_CHECK(config_.epilogue_aux_tensor != nullptr, + "Requested epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor."); + epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor); + } + + // Launch GEMM + cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, + transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N, + with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta, + config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { - NVTE_API_CALL(nvte_cublas_gemm_scaled); + NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -639,17 +862,14 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_atomic_gemm); using namespace transformer_engine; - - // Check CUDA and cuBLAS versions #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif +#else NVTE_CHECK( transformer_engine::cuda::cudart_version() >= 12020 && transformer_engine::cuda::cudart_version() < 13000, @@ -668,13 +888,17 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = convertNVTETensor(counter); Tensor *wspace = convertNVTETensor(workspace); + const void *alpha_ptr = GetScalarOne(); + const void *beta_ptr = accumulate ? GetScalarOne() : GetScalarZero(); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, - n_split, gemm_producer, inputCounter, stream); + alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split, + gemm_producer, inputCounter, stream); +#endif } void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, @@ -695,9 +919,30 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens } for (int i = 0; i < num_gemms; i++) { - nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams)); + // Check whether GELU or dGELU epilogue is requested + Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]); + bool with_gelu_dgelu_epilogue = + (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr); + + // Construct config + MatmulConfig config; + if (grad) { + config.dbias_tensor = bias[i]; + config.with_dgelu_epilogue = with_gelu_dgelu_epilogue; + } else { + config.bias_tensor = bias[i]; + config.with_gelu_epilogue = with_gelu_dgelu_epilogue; + } + config.epilogue_aux_tensor = pre_gelu_out[i]; + config.use_split_accumulator = use_split_accumulator; + config.sm_count = math_sm_count; + + // Launch GEMM + const float alpha = 1.f; + const float beta = accumulate ? 1.f : 0.f; + nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i], + workspace[i % num_streams], &config, + detail::get_compute_stream(i % num_streams)); } // record events on compute streams diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu new file mode 100644 index 000000000..9d4bec41d --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -0,0 +1,876 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kThreadsPerWarp = 32; +constexpr float k16x16HadamardScale = 0.25f; + +template +__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr) { + auto smem_addr = static_cast(__cvta_generic_to_shared(addr)); + if constexpr (kTranspose) { + asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } else { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } +} + +template +__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr, uint32_t stride) { + if constexpr (kTranspose) { + asm volatile( + "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } else { + asm volatile( + "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } +} + +template +__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, void* addr, + uint32_t stride) { + if constexpr (kTranspose) { + asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } else { + asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) { + asm volatile( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) + : "r"(a0)); +} + +__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) { + __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16); + float f_a = __bfloat162float(bf16x2.x); + float f_b = __bfloat162float(bf16x2.y); + asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b)); + float_dst = fabsf(float_dst); +} + +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc( + uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1, + uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3, + uint32_t& amax_result) { + uint32_t zero = 0; + uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; + asm volatile( + "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n" + "{%0, %1, %2, %3, %4, %5, %6, %7}, \n" + "{%8, %9, %10, %11}, \n" + "{%12, %13, %14, %15}, \n" + "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6), + "=r"(temp7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero), + "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6)); + if constexpr (kCalculateAmax) { + uint32_t max_even; + uint32_t max_odd; + // Reduction tree to amax(abs(result)) into bf16x2 reg outparam. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2)); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3)); + // N.B. mma is only called up to once per thread for identity and transpose respectively, so + // we don't have to accumulate into amax_result and can directly store into it. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(amax_result) + : "r"(max_even), "r"(max_odd)); + } +} + +template +__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i, + uint16_t random_sign_mask, + uint32_t* had_frag_t, + uint16_t random_sign_mask_t) { + int32_t tid = threadIdx.x % 32; // Local tid + float temp_i[2]; + float temp_t[2]; +#pragma unroll + for (int i = 0; i < 2; i++) { + // i is the vertical fragment index. + // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals. + uint32_t r = i * 8 + tid / 4; + +#pragma unroll + for (int j = 0; j < 2; j++) { +#pragma unroll + for (int k = 0; k < 2; k++) { + // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits. + // j is the column fragment idx selecting between even and odd fragments. + // j increments 8 columns by switching fragments. + uint32_t c = j * 8 + k + tid % 4 * 2; + // 1 -> -1.0f, 0 -> 1.0f + int32_t base_sign = __popc(r & c); + if constexpr (kReturnIdentity) { + int32_t sign_i; + // Because tensor cores want the dot product dimension, + // contiguous, the regular, non-inverse hadamard swaps + // signs of columns and rows for inverse. In a simple reference, + // x.reshape(-1, 16) @ sign @ H16, this would be opposite but + // (sign @ H16) is transposed in this fragment. + if constexpr (kInverseHadamardIdentity) { + sign_i = ((random_sign_mask >> r) ^ base_sign); + } else { + sign_i = ((random_sign_mask >> c) ^ base_sign); + } + temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31)); + } + if constexpr (kReturnTransposed) { + int32_t sign_t; + if constexpr (kInverseHadamardTransposed) { + sign_t = ((random_sign_mask_t >> r) ^ base_sign); + } else { + sign_t = ((random_sign_mask_t >> c) ^ base_sign); + } + temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31)); + } + } + + if constexpr (kReturnIdentity) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_i[i * 2 + j]) + : "f"(temp_i[1]), "f"(temp_i[0])); + } + if constexpr (kReturnTransposed) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_t[i * 2 + j]) + : "f"(temp_t[1]), "f"(temp_t[0])); + } + } + } +} + +__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, + uint32_t gmem_col_idx) { + uint32_t smem_row_idx = gmem_row_idx; + uint32_t xor_factor = (smem_row_idx * 2) % 8; + uint32_t smem_col_idx = gmem_col_idx ^ xor_factor; + return smem_row_idx * 8 + smem_col_idx; +} + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = (threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. + if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : "r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr) { + if (output_pre_rht_amax_ptr != nullptr) { + *output_pre_rht_amax_ptr = 0; + } + if (output_identity_amax_ptr != nullptr) { + *output_identity_amax_ptr = 0; + } + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = 0; + } +} + +template +__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + uint64_t num_rows, uint64_t row_length) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. + __syncthreads(); + } + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output, + T* __restrict__ output_t, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, uint64_t num_input_rows, + uint64_t num_input_cols, float* __restrict__ amax, + float* __restrict__ amax_t, bool inverse_hadamard) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported."); + + // The whole threadblock will share the same smem. + extern __shared__ __align__(16) T smem[]; + + // Each 32 threads process a 16x16 matrix. There is a (y, z) grid of 16x16. + // If y = 4, z = 4, then each threadblock is processing a 4x4 grid of 16x16 matrices. + int32_t tid = threadIdx.x; + int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z; + int32_t local_bx = threadIdx.y; + int32_t local_by = threadIdx.z; + + // Define the register fragments + uint32_t a_frag[4]; // A matrix fragment + uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major) + uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major) + uint32_t c_frag[4]; // Result fragment + + // row and col for each thread. 32 threads will work together in 128 chunk to + // load the data from global memory to shared memory. + uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4)); + uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4)); + + uint32_t smem_index = tid; + + uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension; + uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension; + + bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows); + if (!load) { + // Out of bound, we are returning early. No thread divergence since the whole warp + // will return early. + return; + } + + uint64_t global_offset = input_start_col + input_start_row * num_input_cols; + uint64_t global_offset_t = + kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset; + + T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id; + + uint32_t* smem_b32 = reinterpret_cast(base_smem); + uint4* smem_b128 = reinterpret_cast(base_smem); + + // Asynchronously load the data from global memory to shared memory. + const uint4* input_b128 = reinterpret_cast(input + global_offset); + // Each 16x16 chunk is divided into 4 8x8 matrices, we are trying to load each + // 8x8 chunks consecutively into the smem, so we could leverage ldmatrix m8n8x4 + // to load the data in the tensor core swizzled format. + __pipeline_memcpy_async(&smem_b128[smem_index], + &input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col], + sizeof(uint4)); + __pipeline_commit(); // Commit the memcpy. Wait when we are in the computation. + + if (inverse_hadamard) { + get_hadamard_matrix_fragment(b_frag_i, random_sign_mask, + b_frag_t, random_sign_mask_t); + } else { + get_hadamard_matrix_fragment( + b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t); + } + + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + __pipeline_wait_prior(0); + + __syncwarp(); // ensure all lanes finished their cp.async before reading smem + + // Load the A to a_frag. + if constexpr (kComputeIdentity) { + load_matrix_16x16_from_shared(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32, + kHadamardDimension); + + // 16x16 @ 16x16 leveraging all threads in the warp. + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg); + + // Store the result to the shared memory in non-transposed order. + if constexpr (kReturnIdentity) { + uint4* output_b128 = reinterpret_cast(output + global_offset); + store_matrix_16x16_to_global(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128, + num_input_cols); + } + } + + if constexpr (kComputeTransposed) { + if (kComputeIdentity) { + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + } else { + load_matrix_16x16_from_shared(a_frag[0], + a_frag[2], // NOTE: intentional index swapping + a_frag[1], // NOTE: intentional index swapping + a_frag[3], smem_b32, kHadamardDimension); + } + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], + // 2,1 is used if we are using movmatrix instruction. + // Thus loading the matrix in 2,1 order will just be normal. + // This is to be compatible with the movmatrix instruction. + a_frag[2], // NOTE: intentional index swapping for transpose purpose. + a_frag[1], // NOTE: intentional index swapping for transpose purpose. + a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1], + c_frag[2], c_frag[3], local_amax_t_reg); + + // Store the result to the shared memory in non-transposed order. + if constexpr (kReturnTransposed) { + uint4* output_t_b128 = reinterpret_cast(output_t + global_offset_t); + store_matrix_16x16_to_global( + c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128, + kOutputTrueTransposed ? num_input_rows : num_input_cols); + } + } + + if constexpr (kUpdateIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + local_amax = warp_reduce_max(local_amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero); + // atomic CAS to output memory. + if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax, local_amax); + } + } + if constexpr (kUpdateTransposeAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + local_amax_t = warp_reduce_max(local_amax_t); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero); + // atomic CAS to output memory. + if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax_t, local_amax_t); + } + } +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+."); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +} + +} // namespace + +void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform); + + // Check tensors + // NOTE (frsun): This is non-intuitive, we are writing the result of + // transposed RHT to the output of rowwise. + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be simple tensor, but scaling mode is ", + to_string(output_.scaling_mode), "."); + const SimpleTensor& input = input_.data; + SimpleTensor output; + SimpleTensor& output_t = output_.data; + + // Check requested outputs + const bool return_identity = output.dptr != nullptr; + const bool return_transposed = output_t.dptr != nullptr; + if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior. + return; + } + + checkCuDriverContext(stream); + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + using IType = bf16; + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kThreadBlockX = 4; + // Configure 4 is used for Hopper, 8 is used for Blackwell for extra memory bandwidth. + constexpr uint64_t kThreadBlockY = 4; + + uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY; + + // The shared memory number of bytes required for **the whole threadblock**. + size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM; + + dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY); + + dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX), + DIVUP(num_rows / kHadamardDimension, kThreadBlockY)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed, kReturnTransposed, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + auto kernel = + HadamardTransformKernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes); + + kernel<<>>( + reinterpret_cast(input.dptr), reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), random_sign_mask, random_sign_mask_t, + num_rows, row_length, nullptr, nullptr, false););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then +// get the absolute max value of the result. +void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_amax); +#if CUDA_VERSION >= 12080 + + // Check input tensor + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor& input = input_.data; + + // Check amax tensors + SimpleTensor& output_pre_rht_amax = output_.amax; + SimpleTensor output_identity_amax; + SimpleTensor& output_transpose_amax = output_.columnwise_amax; + + // Check requested outputs + const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr; + const bool return_identity_amax = output_identity_amax.dptr != nullptr; + const bool return_transposed_amax = output_transpose_amax.dptr != nullptr; + if (!return_identity_amax && !return_transposed_amax && + !return_pre_rht_amax) { // Nothing to do/ill-defined behavior. + return; + } + + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kChunkBlockXSmall = 128; + constexpr uint64_t kChunkBlockYSmall = 128; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input, + /*globalY=*/num_rows, + /*globalX=*/row_length, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/row_length, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + + dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed_amax, kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity_amax, kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = HadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>( + tensor_map_input, reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, + random_sign_mask_t, num_rows, row_length);))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform); + using namespace transformer_engine; + hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} + +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_amax); + using namespace transformer_engine; + hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu new file mode 100644 index 000000000..ce191b5ff --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -0,0 +1,841 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "curanddx.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/helper_cuda.hpp" +#include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +// Define a cuRANDDx descriptor +// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. +// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., +// if shared memory, if needed, is enough for the described problem, usually not applicable. + +// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread()); + + +using namespace cute; +using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +// calculate the global encode scale factor for a given global amax. +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float kFP8E4M3Max = 448.0f; + constexpr float kFP4E2M1Max = 6.0f; + // If scale is infinity, return max value of float32 + float global_encode_scale = cutlass::minimum_with_nan_propagation{}( + kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); + // If global amax is 0 or infinity, return 1 + return (global_amax == 0.f || global_encode_scale == 0.f) ? 1.f : global_encode_scale; +} + +template +struct SharedStorage { + static constexpr int AccumulatorPipelineStageCount = 16; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + +}; + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + auto output_ptr = reinterpret_cast(&output); + asm volatile( \ + "{\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ + "}" \ + : "=h"(output_ptr[0]), + "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), + "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), + "r"(rbits[0]), "r"(rbits[1])); +#else + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return output; +} + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const *rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +__global__ static +void +rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, + TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TC * C, CStride dC, CSmemLayout , + TSFC * SFC, + TiledMMA mma, + float const* global_amax, + const size_t* rng_state) +{ + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static constexpr uint32_t kTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + + static constexpr int kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + static constexpr int AccumulatorPipelineStageCount = 16; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; + static constexpr int VectorSize = 16; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); + Tensor mC = make_tensor(cute::subbyte_iterator(C), make_shape(M,N), dC); // (M,N) + + auto sfc_shape = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), N / 64 ) + ); + + auto sfc_stride = make_stride( + N / 16, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + auto sfc_layout = make_layout(sfc_shape, sfc_stride); + Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout); + + auto cluster_shape = Shape< _1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + + auto mainloop_tiler = Shape<_128,_16,_64>{}; + auto epilogue_tiler = Shape<_128,_64,_64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler)); + + auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma, + Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue, + Int{})); + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition(tma_load_a, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = tma_partition(tma_load_b, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + if (is_epilogue_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(global_amax)); + } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,tile_idx_m,_); + int k_tile = 0; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if (cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage)); + } + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ) + { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) + { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + } + + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + const float global_amax_val = *global_amax; + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N) + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + + const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + const float global_decode_scale = 1.0f / global_encode_scale; + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile); + + Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor(shape(tDgC)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout( + make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}) + )); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + // Cast data from FP32 to BF16 to FP32. + auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + + // Initialize RNG for tile + const size_t rng_sequence + = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = dist.generate4(rng); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale + ), + reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}(cutlass::multiplies>{}(compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } +} + +// this function computes RHT-GEMM for +// A: m x n: col-major +// B: 16 x 16: row-major +// C: m x n: row-major +// SFC: m x (n/16): row-major +template +void +rht_gemm_ntt_w_sfc(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 2048) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = static_cast(m); + auto N = static_cast(n); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dB = make_stride(Int<1>{}, 16); // (dN,dK) + auto dC = make_stride(n, Int<1>{}); // (dM,dN) + + auto cga_shape = Shape< _1, _1, _1>{}; + auto cga_tile_shape = Shape<_128,_16,_16>{}; + auto cluster_tile_mainloop = Shape<_128,_16,_64>{}; + + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + + // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} + + // Assert that the TiledMMA uses all CTAs in the CGA. + CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy + + // Create GMEM tensors + Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16) + + // Create the TiledCopy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cga_tile_shape, + mma); + + // Assert checks on tile sizes -- no predication + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, + "Inner dimension must be divisible by ", static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, + "Outer dimension must be divisible by ", 4 * static_cast(size<1>(cga_tile_shape)), + " but got ", N, "."); + + uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); + + tiles = (tiles < sm_count) ? tiles : sm_count; + + dim3 dimBlock(256); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles, 1, 1); + + int smem_size = sizeof(SharedStorage); + auto* kernel_ptr = &rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TC, decltype(dC), decltype(sC), + TSFC, + decltype(mma), + kEnableStochasticRounding>; + + bool status = cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (status != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." << std::endl; + return; + } + (*kernel_ptr) + <<< dimGrid, dimBlock, smem_size, stream >>> + (M, N, k_tile_size, cga_tile_shape, + A, dA, sA, tma_load_a, + B, dB, sB, tma_load_b, + C, dC, sC, + SFC, + mma, global_amax, + rng_state); +} + +// this function is used to wrap the rht_gemm_ntt_w_sfc function +//to transpose the input tensor A +template +void +rht_gemm_ttt_wrapper(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) +{ + // in addition to transpose the input tensor A + // we also need to reshape m, n to at best + // ultilize as many SMs as possible while keeping + // a relatively large contiguous dimension. + // for example, after swapping m, n for transpose purposes, + // the input / output tensor shapes for RHT-GEMM are: + // A: n x m: col-major + // B: 16 x 16: row-major + // C: n x m: row-major + // SFC: n x (m/16): row-major + rht_gemm_ntt_w_sfc( + n, m, + A, B, C, + SFC, global_amax, + rng_state, + sm_count, stream, + k_tile_size); +} + +} // namespace +} // namespace detail + +// clang-format on + +void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, + QuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); + + // Check input and output tensors + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + SimpleTensor &global_amax = output_.amax; + SimpleTensor &output_t = output_.data; + SimpleTensor &scale_inv_t = output_.scale_inv; + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by hadamard_dimension"); + + int k_tile_size = 1024; + + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size);); +} + +} // namespace transformer_engine + +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise); + using namespace transformer_engine; + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + hadamard_transform_cast_fusion_columnwise( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0c358328b..950014cc9 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -15,9 +15,76 @@ #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus -/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. +/*! \brief Configuration for matrix multiplication. */ +typedef void *NVTEMatmulConfig; + +/*! \enum NVTEMatmulConfigAttribute + * \brief Type of option for matrix multiplication. + */ +enum NVTEMatmulConfigAttribute { + /*! Bias tensor + * + * If provided, the bias tensor is applied in the GEMM epilogue. + */ + kNVTEMatmulConfigBiasTensor = 0, + /*! Bias gradient tensor + * + * If provided, the bias gradient tensor will be filled in the GEMM epilogue. + */ + kNVTEMatmulConfigDBiasTensor = 1, + /*! Whether to compute GELU in GEMM epilogue. */ + kNVTEMatmulConfigWithGELUEpilogue = 2, + /*! Whether to compute GELU backward in GEMM epilogue. */ + kNVTEMatmulConfigWithDGELUEpilogue = 3, + /*! Auxilliary tensor for GEMM epilogue. + * + * For GELU, this will be filled with the GELU input. For GELU + * backward, this is expected to already be filled with the GELU + * input. + */ + kNVTEMatmulConfigEpilogueAuxTensor = 4, + /*! Whether to use split accumulator for FP8 GEMM. */ + kNVTEMatmulConfigUseSplitAccumulator = 5, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEMatmulConfigSMCount = 6, + kNVTEMatmulConfigNumAttributes +}; + +/*! \brief Create a matrix multiplication configuration. */ +NVTEMatmulConfig nvte_create_matmul_config(); + +/*! \brief Query an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a matrix multiplication configuration. */ +void nvte_destroy_matmul_config(NVTEMatmulConfig config); + +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -44,8 +111,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. + * + * Computes: + * - `D = alpha * op(A) * op(B) + beta * C` + * + * \param[in] transa Whether to transpose A matrix. + * \param[in] transb Whether to transpose B matrix. + * \param[in] alpha Scaling factor applied to matmul output. + * \param[in] A A matrix. + * \param[in] B B matrix. + * \param[in] beta Scaling factor applied to C matrix. + * \param[in] C C matrix. + * \param[out] D Output matrix. + * \param[in] workspace Workspace tensor. + * \param[in] config Additional configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations, - * allowing for using a scaling factor for the GEMM result and the accumulation input + * allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated) + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -133,14 +223,16 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, - bool transa, bool transb, bool grad, NVTETensor* workspace, +void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); #ifdef __cplusplus } // extern "C" -#endif +#endif // __cplusplus + +#ifdef __cplusplus /*! \namespace transformer_engine */ @@ -153,6 +245,89 @@ namespace transformer_engine { void nvte_cublas_handle_init(); +/*! \struct MatmulConfigWrapper + * \brief C++ wrapper for NVTEMatmulConfig. + */ +class MatmulConfigWrapper { + public: + MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {} + + MatmulConfigWrapper(const MatmulConfigWrapper &) = delete; + MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete; + + MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~MatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEMatmulConfig. + * + * \return NVTEMatmulConfig held by this MatmulConfigWrapper. + */ + operator NVTEMatmulConfig() const noexcept { return config_; } + + /*! \brief Set bias tensor. */ + void set_bias_tensor(NVTETensor bias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set bias gradient tensor. */ + void set_dbias_tensor(NVTETensor dbias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to compute GELU in GEMM epilogue. */ + void set_with_gelu_epilogue(bool with_gelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue, + &with_gelu_epilogue, sizeof(bool)); + } + + /*! \brief Set whether to compute GELU backward in GEMM epilogue. */ + void set_with_dgelu_epilogue(bool with_dgelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue, + &with_dgelu_epilogue, sizeof(bool)); + } + + /*! \brief Set auxilliary tensor for GEMM epilogue. */ + void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor, + &epilogue_aux_tensor, sizeof(NVTETensor)); + } + + /*! \brief Set whether to use split accumulator for FP8 GEMM. */ + void set_use_split_accumulator(bool use_split_accumulator) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator, + &use_split_accumulator, sizeof(bool)); + } + + /*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */ + void set_sm_count(int sm_count) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int)); + } + + private: + /*! \brief Wrapped NVTEMatmulConfig. */ + NVTEMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine +#endif // __cplusplus + #endif // TRANSFORMER_ENGINE_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h new file mode 100644 index 000000000..a0dd325da --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -0,0 +1,68 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file hadamard_transform.h + * \brief Functions for Hadamard transforms. + */ + +#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ +#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Perform a randomized Hadamard transform on the input tensor. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the absolute maximum reduction on the input tensor with/without + * randomized hadamard transform. The rowwise result is the absolute maximum + * of the input tensor. The columnwise result is the absolute maximum of the + * input tensor transposed and applied randomized hadamard transformation. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the columnwise hadamard transform cast fusion. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] hadamard_matrix Hadamard matrix. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 2fc8c1095..6e1e9dd7a 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -122,6 +122,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, size_t start_offset, size_t block_len, const NVTEDType out_dtype, cudaStream_t stream); +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index dab4fcfe7..1a901ab82 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -66,6 +66,7 @@ enum NVTETensorParam { kNVTEAmax = 3, /*!< Amax tensor */ kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */ kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ + kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTENumTensorParams }; @@ -88,10 +89,9 @@ enum NVTEScalingMode { */ NVTE_BLOCK_SCALING_1D = 2, NVTE_BLOCK_SCALING_2D = 3, - /*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD), - and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD). - */ - NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4, + /*! Single scale per block of 16 elements consecutive in either + * rowwise or columnwise direction */ + NVTE_NVFP4_1D_SCALING = 4, NVTE_INVALID_SCALING = 100 }; @@ -330,6 +330,12 @@ enum NVTEQuantizationConfigAttribute { * likely be refactored away in the future. */ kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3, + /*! RNG state (NVTETensor with 2 elements - seed and offset */ + kNVTEQuantizationConfigRNGState = 4, + /*! Whether to use 2D block scaling for NVFP4 */ + kNVTEQuantizationConfigNVFP42DQuantization = 5, + /*! Whether to enable stochastic rounding */ + kNVTEQuantizationConfigStochasticRounding = 6, kNVTEQuantizationConfigNumAttributes }; @@ -431,6 +437,15 @@ inline bool is_fp8_dtype(const DType t) { */ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } +/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16) + * + * Return true if TE datatype is high precision + * \param[in] DType TE Datatype of interest + */ +inline bool is_high_precision_dtype(const DType t) { + return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16; +} + /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. */ @@ -566,6 +581,11 @@ class TensorWrapper { return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); } + template + TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -590,6 +610,10 @@ class TensorWrapper { return get_parameter(kNVTEColumnwiseScaleInv); } + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEColumnwiseAmax); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -838,6 +862,24 @@ class QuantizationConfigWrapper { &format, sizeof(Float8BlockScaleTensorFormat)); } + /*! \brief Set stochastic rounding state */ + void set_rng_state(NVTETensor rng_state) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to use 2D block scaling for NVFP4 */ + void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization, + &nvfp4_2d_quantization, sizeof(bool)); + } + + /*! \brief Set whether to use stochastic rounding */ + void set_stochastic_rounding(bool stochastic_rounding) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding, + &stochastic_rounding, sizeof(bool)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 398c0acbd..5785fd223 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -28,7 +28,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -63,7 +63,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 82e360ed6..a3b05f7a2 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -24,7 +24,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -49,7 +49,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index fc8d73a13..ea0287ef1 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -4,7 +4,6 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations -import warnings import os from enum import Enum from typing import Literal, Optional, Union, Callable, NamedTuple @@ -23,9 +22,12 @@ class _FormatHelper(NamedTuple): class Format(Enum): """ Supported FP8 formats. + Supported FP4 formats. Values ------ + E2M1 : + All FP4 tensors are in e2m1 format E4M3 : All FP8 tensors are in e4m3 format E5M2 : @@ -35,6 +37,7 @@ class Format(Enum): FP8 tensors in the backward pass are in e5m2 format """ + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) @@ -42,9 +45,13 @@ class Format(Enum): @dataclass(frozen=True) class MMParams: - """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator) - apply split accumulator or not, turning it on will increase accuracy but impact gemm performance, - so only turn it on for certain gemms + """Matrix multiplication options. + + Parameters + ---------- + use_split_accumulator : bool, default = `True` + Use FP8 fast accumulation on Hopper or Ada. For more details, + see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul. """ use_split_accumulator: bool = True @@ -55,10 +62,24 @@ class QParams: """Quantization parameters. power_2_scale: use power of 2 scale parameter amax_epsilon: optional minimum value of abs max + random_hadamard_transform: whether to use random hadamard transform + stochastic_rounding: whether to use stocastic rounding """ power_2_scale: bool = False amax_epsilon: float = 0.0 + random_hadamard_transform: bool = False + stochastic_rounding: bool = False + fp4_2d_quantization: bool = False + + def __repr__(self) -> str: + return ( + f"Qparams(\npower_2_scale={self.power_2_scale},\n" + f"amax_epsilon={self.amax_epsilon},\n" + f"random_hadamard_transform={self.random_hadamard_transform},\n" + f"stochastic_rounding={self.stochastic_rounding},\n" + f"fp4_2d_quantization={self.fp4_2d_quantization}\n)" + ) class Recipe: @@ -66,6 +87,10 @@ class Recipe: Base recipe class. """ + def nvfp4(self): + """Whether the given recipe is NVFP4 1D block scaling.""" + return isinstance(self, NVFP4BlockScaling) + def mxfp8(self): """Whether the given recipe is MXFP8 block scaling.""" return isinstance(self, MXFP8BlockScaling) @@ -351,3 +376,84 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class NVFP4BlockScaling(Recipe): + """ + Use the NVFP4 scaling strategy. + + This is a 2-level block scaling strategy. In level 1, each group of + 16 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E4M3 (4 bits of exponent, + 3 bits of mantissa). In level 2, a global per tensor FP32 scaling + factor is used to scale the entire tensor. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. + + Parameters + ---------- + fp4_format : {Format.E2M1}, default = Format.E2M1 + FP4 data type. + fp8_format : {Format.E4M3}, default = Format.E4M3 + FP8 data type. Only E4M3 is supported. + fp8_dpa: bool, default = `False` + FP8 dot product attention. Not yet supported. + fp8_mha: bool, default = `False` + FP8 multi-head attention. Not yet supported. + """ + + # Configuration envvars + disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1" + disable_stochastic_rounding: bool = ( + os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" + ) + disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + + fp4_format: Format = Format.E2M1 + fp8_format: Format = Format.E4M3 + + # Not applying quantization to attention for now + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" + assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + + # Quantization params + # Note: RHT is currently only applied to column-wise usage so that + # it can be used for wgrad GEMM. + self.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=False, + fp4_2d_quantization=False, + ) + self.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, + stochastic_rounding=False, + fp4_2d_quantization=not self.disable_2d_quantization, + ) + self.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=not self.disable_stochastic_rounding, + fp4_2d_quantization=False, + ) + + def __repr__(self) -> str: + return ( + f"recipe_type={self.__class__.__name__}, " + f"fp4_format={str(self.fp4_format).split('.')[1]}, " + f"fp8_format={str(self.fp8_format).split('.')[1]}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}, " + f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " + f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " + f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " + ) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index fd907efcb..ee2c84515 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -20,6 +20,13 @@ namespace { constexpr int amax_kernel_threads = 512; +__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *amax_ptr = 0; +} + template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, const size_t N, @@ -65,7 +72,8 @@ template void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr, cudaStream_t stream) { // Zero out amax so we can update with atomic max - NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); + zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); // Return immediately if tensor is empty if (N == 0) { @@ -130,15 +138,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt // Check output tensor NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); auto &output = *convertNVTETensorCheck(output_); - NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || + output.scaling_mode == NVTE_NVFP4_1D_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling or " + "NVFP4 1D scaling, " "but got scaling_mode=", to_string(output.scaling_mode)); NVTE_CHECK(output.amax.numel() == 1, "Output tensor for amax computation has invalid amax tensor " "(expected 1 entry, got shape=", output.amax.shape, ")"); - NVTE_CHECK(output.amax.dptr != nullptr, + NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr, "Output tensor for amax computation has amax tensor without data"); NVTE_CHECK(output.amax.dtype == DType::kFloat32, "Output tensor for amax computation has invalid amax tensor " @@ -157,11 +167,12 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt } // Compute amax + float *amax_ptr = reinterpret_cast( + (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); - launch_amax_kernel(reinterpret_cast(input.data.dptr), - reinterpret_cast(output.amax.dptr), input.data.numel(), - noop_ptr, stream);); // NOLINT(*) + input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel( + reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr, + stream);); // NOLINT(*) } } // anonymous namespace diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu new file mode 100644 index 000000000..5ebc7ba4f --- /dev/null +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include + +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { +namespace nvfp4_recipe { + +// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; +constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); + +// Kernel to compute alpha *= amax_A * amax_B / factor +__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, + const float *amax_B, float *alpha_out) { + // factor is defined in the enclosing namespace + *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; +} + +} // namespace nvfp4_recipe +} // namespace transformer_engine + +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale); + using namespace transformer_engine; + + auto *tA = convertNVTETensor(inpA); + auto *tB = convertNVTETensor(inpB); + auto *tOut = convertNVTETensor(alpha_out); + + void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; + void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; + void *alpha_ptr = tOut->data.dptr; + + // check for not null pointers + NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); + NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null"); + NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null"); + + nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( + alpha_in, reinterpret_cast(amax_A_ptr), + reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 9ec86a37c..36e06173d 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -18,7 +18,9 @@ namespace transformer_engine { namespace { -constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr int MXFP8_BLOCK_SIZE = 32; +constexpr int NVFP4_BLOCK_SIZE = 16; + constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -314,8 +316,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int original_K = kernel_args.original_k_list[tensor_id]; constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); - constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; - constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; // Get block index in grid. Emulate 2D grid. const int num_tiles_k = K / SF_TILE_DIM_K; @@ -332,9 +332,13 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { - NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); - } + NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING || + input->scaling_mode == NVTE_BLOCK_SCALING_1D || + input->scaling_mode == NVTE_BLOCK_SCALING_2D || + input->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), + "Input tensor has invalid dtype (", to_string(input->dtype()), ")."); // Do nothing if tensor is empty if (input->data.numel() == 0) { @@ -345,123 +349,150 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s CheckInputTensor(*output, "scaling_factor_output"); auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, + "Unsupported scaling mode for swizzling."); + + bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; // 1D block scaling, row-wise or colum-wise - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int m = - input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; - const int k = - input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - - NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); - NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); - NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - if (output->has_data()) { - NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), - output->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), - output->columnwise_scale_inv.shape.end(), 1, - std::multiplies()), - "Input.columnwise_scale_inv size is not equal to " - "Output.columnwise_scale_inv size!"); + int m, k; + if (input->has_data()) { + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else { + if (nvfp4) { + m = input->columnwise_scale_inv.shape[0]; + k = input->columnwise_scale_inv.shape[1]; + } else { + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; } + } - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; - dim3 block_size(TB_DIM, TB_DIM); - if (input->has_data()) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 2: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + // For NVFP4, the scale inverse for tranposed data needs rowwise swizzle. + const bool rowwise_swizzle = input->has_data() || nvfp4; + const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4; + + dim3 block_size(TB_DIM, TB_DIM); + if (rowwise_swizzle) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + int original_M, original_K; + void *input_scale_inv_ptr, *output_scale_inv_ptr; + + if (!nvfp4 || input->has_data()) { + int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + original_M = input->flat_first_dim(); + original_K = input->flat_last_dim() / block_scale_size; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + } else { + original_M = input->flat_last_dim(); + original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->columnwise_scale_inv.dptr; + output_scale_inv_ptr = output->columnwise_scale_inv.dptr; } - if (input->has_columnwise_data()) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 2: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 1: - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; } + } + if (columnwise_swizzle) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); - // 2D block scaling - } else { - NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -551,6 +582,8 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, } NVTE_CHECK_CUDA(cudaGetLastError()); } + +// TODO(nvfp4): Add NVFP4 support. void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); @@ -677,7 +710,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, * WIP (Phuong): * - Opt for bank conflicts * - Adding swizzle for 2d-block scaling. -*/ + */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swizzle_scaling_factors); using namespace transformer_engine; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 55654989a..f49fe239a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include "common.h" #include "common/util/cuda_runtime.h" @@ -63,8 +64,8 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: - return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; + case NVTE_NVFP4_1D_SCALING: + return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: return "NVTE_INVALID_SCALING"; } @@ -94,12 +95,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { t.columnwise_scale_inv.shape, ")"); } } else { - if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || - t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = std::vector{128ul, 4ul}; size_t expected_x, expected_y, alignment; - const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16; + const size_t block_size_rowwise = 32; const size_t block_size_colwise = 32; if (t.has_data()) { @@ -110,6 +110,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", @@ -122,11 +123,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { alignment; alignment = block_alignment[0]; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } + } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { + if (t.has_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + } } } } @@ -154,6 +173,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { "(expected Float32 or Byte, got ", to_string(t.columnwise_scale_inv.dtype), ")"); } + } else if (is_fp4_dtype(type)) { + // TODO(ksivaman): Fix this to check for amaxes and other details. + // For now only needed for swizzle. + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected DType::kFloat8E4M3, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected DType::kFloat8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); @@ -195,10 +234,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt "(expected Float32 or Float8E8M0, got ", to_string(t.columnwise_scale_inv.dtype), ")"); } + } else if (is_fp4_dtype(type)) { + // FP4 output needs to have the scale_inv + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); - // Note: amax is supported for non-FP8 output as it can be fused into the computation - // and later used for quantization with no need to compute it separately + // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax. + // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); @@ -491,6 +549,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, case kNVTEColumnwiseScaleInv: t->columnwise_scale_inv = *param; break; + case kNVTEColumnwiseAmax: + t->columnwise_amax = *param; + break; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -514,6 +575,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p return t.scale_inv; case kNVTEColumnwiseScaleInv: return t.columnwise_scale_inv; + case kNVTEColumnwiseAmax: + return t.columnwise_amax; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -629,6 +692,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size); break; + case kNVTEQuantizationConfigRNGState: + std::memcpy(&config_.rng_state, buf, attr_size); + break; + case kNVTEQuantizationConfigNVFP42DQuantization: + std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size); + break; + case kNVTEQuantizationConfigStochasticRounding: + std::memcpy(&config_.stochastic_rounding, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index abfa226e8..89266f4bb 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine::detail { @@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor const bool pow_2_scale, const SimpleTensor &noop_tensor, cudaStream_t stream); +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv, + SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor &noop_tensor, cudaStream_t stream); + } // namespace transformer_engine::detail #endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu new file mode 100644 index 000000000..eced2c4bb --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -0,0 +1,842 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "curanddx.hpp" + +namespace transformer_engine { + +#if CUDA_VERSION >= 12080 +namespace quantize_transpose_nvfp4 { +namespace { + +using std::int32_t; +using std::uint32_t; +using std::uint8_t; + +using transformer_engine::detail::TypeExtrema; + +// Define a cuRANDDx descriptor +// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. +// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., +// if shared memory, if needed, is enough for the described problem, usually not applicable. +// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr int kThreadsPerWarp = 32; + +// for fp4, we use uint8_t to store 2 fp4 numbers +constexpr int kNFP4PerContainer = 2; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; +// constexpr int kScaleDim = 32; +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16 +constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 +// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +// for 2D block scaling, we need to reduce amax in warp +static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 0x80808080}; + +// max for every group_size elements in warp +template +__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { + for (int offset = group_size / 2; offset > 0; offset /= 2) { + val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride)); + } + return val; +} + +template +__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax, + const float global_encode_scale) { + float decode_scale = amax / TypeExtrema::max; + decode_scale = decode_scale * global_encode_scale; + decode_scale = fminf(decode_scale, TypeExtrema::max); + return static_cast(decode_scale); +} + +template +__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale, + const float global_decode_scale) { + return fminf(1.0f / (static_cast(decode_scale) * global_decode_scale), + TypeExtrema::max); +} + +template +__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) { + return static_cast(input) * encode_scale; +} + +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float fp8_max = TypeExtrema::max; + constexpr float fp4_max = TypeExtrema::max; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.f || global_encode_scale == 0.f) { + return 1.f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits dist; + random_uint4 = dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +template +__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx, + uint32_t col_length) { + // This function takes in indices from the scale factor matrix and returns an offset in the + // swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled + // index). col_length is the column length of the scale factor matrix. tile_scales_inv is the + // pointer to the scale factor matrix. + + // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts + // For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4 + // columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4 + // columns. + + // NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix. + // To think in high level, the swizzled scale factor matrix could be composed as: + // unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8) + // cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64 + // rb_cnt = M // 128 # Assuming M is divisible by 128 + // tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4) + // tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + // swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4)) + + constexpr uint32_t kTotalRowsPerBaseBlock = 128; + constexpr uint32_t kRowsPerBaseBlockCol = 32; + constexpr uint32_t kColsPerBaseBlockCol = 4; + + const size_t rb = row_idx / kTotalRowsPerBaseBlock; + const size_t rem = row_idx % kTotalRowsPerBaseBlock; + const size_t d4 = rem / kRowsPerBaseBlockCol; + const size_t d3 = rem % kRowsPerBaseBlockCol; + const size_t cbg = col_idx / kColsPerBaseBlockCol; + const size_t d5 = col_idx % kColsPerBaseBlockCol; + + const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol); + // row-major offset in the logical shape + // (rb_cnt , cbg_cnt , 32 , 4 , 4) + // Magic number 16 below comes from the fact we have kColsPerBaseBlockCol = 4, and d4 ([0-128] / + // 32 = [0-4]) + return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const uint32_t rbits) { +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + uint16_t out_4x; + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + "}" + : "=h"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const uint32_t rbits) { +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + // NOTE: rbits unused for rn. + uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + asm volatile( + "{\n" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL +} + +template +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const uint32_t rbits) { + if constexpr (kApplyStochasticRounding) { + return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); + } else { + return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits); + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, const float* global_amax, OType* const output_c, + OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, + const size_t row_length, const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, + const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state, + const float* noop_ptr) { + constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer; + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + + const size_t block_idx_x = blockIdx.x; + const size_t block_idx_y = blockIdx.y; + const size_t rng_sequence = + threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. + // Instead of static_assert, return early if these invalid modes are detected. + if constexpr (kIs2DBlockScaling && kIsE8Scaling) { + return; + } + if constexpr (kIs2DBlockScaling && !kReturnIdentity) { + return; + } + // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 + // use constexpr to define the size, when not using 2D, use minimal size 1x1 + constexpr int kFP4BlockScalingSize = 16; + constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1; + constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4 + constexpr int k2DBlockAmaxReduceDim = + kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1; + __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; + __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = (c_g < row_length ? min(static_cast(kNVecIn), row_length - c_g) + : 0); // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory + // for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; + const float global_encode_scale = + kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); + const float global_decode_scale = 1.0 / global_encode_scale; + + // Step 2: Cast and store to output_c + if constexpr (kReturnIdentity) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = + (c_g < row_length ? min(static_cast(kNVecOut / kNFP4PerContainer), + (row_length - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax + if constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // doing shuffle sync for 2D block scaling (not applicable for E8 scaling) + if constexpr (kIs2DBlockScaling) { + // first amax shuffle sync in warp, then reduce in smem + // T0 T8 T16 T24 should do amax reduction together + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + // now T0 ~ T8 in each warp has the reduced amax values + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = fmaxf(amax_2d, + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + // every thread now knows 2D amax + amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; + } + // Step 2.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + write_scale_inv &= (c_g < row_length); + } + if (write_scale_inv) { + size_t row_idx = block_idx_y * kTileDim + r_s; + size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce; + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(row_length, kScaleBlockDim)); + tile_scales_inv_c[offset] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); + const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory + // for not aligned case) + output_g += stride_g / kNFP4PerContainer; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory + const size_t stride_g = + static_cast(c_stride) * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = (c_g < num_rows ? min(static_cast(kNVecOut / kNFP4PerContainer), + (num_rows - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; + if constexpr (kIs2DBlockScaling) { + // TODO(zhongbo): 2D block scaling, directly read from amax_smem + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + constexpr int kNumColsPerWarp = + kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements + constexpr int kNumWarpsPerBlock = + kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block + constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock; + int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp; + int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore; + int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x; + amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize]; + } else { +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + } + // Step 3.3: Reduce amax + if constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // Step 3.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + write_scale_inv &= (c_g < num_rows); + } + if (write_scale_inv) { + size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce); + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim)); + tile_scales_inv_t[offset] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = + ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], + encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], + encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], + encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, + num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global + // memory for not aligned case) + output_g += stride_g / kNFP4PerContainer; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace quantize_transpose_nvfp4 +#endif // CUDA_VERSION >= 12080 + +namespace detail { + +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor& noop_tensor, cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); +#if CUDA_VERSION >= 12080 + + // pow 2 scale is for MXFP4 since it's using E8M0 scaling + // raise error if pow2_scale is true + NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now"); + + if (!return_identity && !return_transpose) { + return; + } + + if (use_2d_quantization && !return_identity) { + return; + } + + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + + if (return_identity) { + scale_stride_x = 1; + scale_stride_y = scale_inv.shape[1]; + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + using namespace transformer_engine::quantize_transpose_nvfp4; + + const size_t num_blocks_x = DIVUP(row_length, static_cast(kTileDim)); + const size_t num_blocks_y = DIVUP(num_rows, static_cast(kTileDim)); + + // noop tensor for cuda graph + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + + const size_t* rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY( + output.dtype, 2, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16; + constexpr bool kPow2Scale = false; + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + swizzled_scale, kSwizzledScale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kApplyStochasticRounding, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, + num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, + noop_ptr);) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace detail +} // namespace transformer_engine diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 50ff82d85..6093b54b6 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -598,6 +598,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { @@ -828,6 +829,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; @@ -947,6 +949,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; @@ -1260,7 +1263,7 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_gated(gated_input, output, stream); } } - } else if (is_mxfp_scaling(output->scaling_mode)) { + } else if (is_mxfp8_scaling(output->scaling_mode)) { if (use_tma_kernels) { cast_mxfp8_gated(grad, gated_input, output, stream); } else { diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 8d8735118..b0498602b 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -23,6 +23,7 @@ #include "../util/vectorized_pointwise.h" #include "../utils.cuh" #include "math.h" +#include "nvfp4_transpose.cuh" #include "ptx.cuh" #include "transformer_engine/transformer_engine.h" @@ -108,6 +109,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -135,8 +138,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_sh = reinterpret_cast(dshmem); IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; @@ -284,7 +288,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float scaled_out = in * block_scale_inverse; const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); } } @@ -408,10 +412,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; @@ -439,7 +445,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); } } @@ -454,19 +460,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); } if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); } // Create a "bulk async-group" out of the previous bulk copy operation. @@ -487,18 +493,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Added extra 1-element padding per thread_X to reduce bank conflicts float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const size_t shmem_thread_offset = + const int shmem_thread_offset = tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; - const size_t shmem_elt_idx = swizzled_group_offset + e; + const int shmem_elt_idx = swizzled_group_offset + e; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } @@ -506,15 +512,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int i = 0; i < THREADS_Y; ++i) { // Add extra element offset per MXFP8 scaling block [1x32] - const size_t scaling_block = threadIdx.x / SCALE_DIM_X; + const int scaling_block = threadIdx.x / SCALE_DIM_X; thread_partial_dbias += partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } - const size_t dbias_stride = cols; - const size_t dbias_offset_Y = blockIdx.y; - const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); if (!col_out_of_bounds_dbias) { dbias_workspace[dbias_idx] = thread_partial_dbias; @@ -536,6 +542,528 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } } // namespace mxfp8_kernel +namespace nvfp4_kernel { + +using namespace ptx; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} + +#define DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. + if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace nvfp4_kernel + constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; constexpr size_t FP8_THREADS_PER_CHUNK = 128; @@ -898,7 +1426,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, } template -static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { const size_t N = product(input.data.shape); const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); @@ -1179,6 +1707,141 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ); // NOLINT(*) } +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +template +void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { + using namespace nvfp4_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = + use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = + reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = + use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem + + TMA_SHMEM_ALIGNMENT; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_nvfp4_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_nvfp4_kernel + <<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + }); // NOLINT(*) + ); // NOLINT(*) +} + namespace detail { using Empty = transformer_engine::Empty; @@ -1386,20 +2049,33 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o auto dbias_tensor = convertNVTETensor(dbias); auto workspace_tensor = convertNVTETensor(workspace); - const QuantizationConfig *quant_config_cpp = - reinterpret_cast(quant_config); + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } - // extract noop tensor from quant_config_cpp if it's not null - const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + // Dispatch to quantization kernel depending on data format switch (output_tensor->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { if (output_tensor->has_columnwise_data()) { NVTE_CHECK(output_tensor->has_data(), "Quantizing in only the columnwise direction not supported yet!"); if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); } else { cast_transpose_fused( *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, @@ -1407,51 +2083,90 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o } } else if (output_tensor->has_data()) { fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); } break; } case NVTE_MXFP8_1D_SCALING: { mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, + *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); break; } + case NVTE_NVFP4_1D_SCALING: { + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && + output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4_quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4_quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), + "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for " + "2D quantization"); + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } case NVTE_BLOCK_SCALING_2D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor.data, stream); + /*noop_tensor=*/noop_tensor->data, stream); break; } case NVTE_BLOCK_SCALING_1D: { // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; } if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); columnwise_option = columnwise_compact ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -1459,7 +2174,7 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_vector_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor.data, stream); + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } default: diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e2d8d34f3..9f70ce4cd 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include "../common.h" @@ -26,6 +28,7 @@ #include "math.h" #include "ptx.cuh" #include "transformer_engine/activation.h" +#include "transformer_engine/transformer_engine.h" #include "transformer_engine/transpose.h" namespace transformer_engine { @@ -226,7 +229,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); @@ -247,7 +250,7 @@ static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t str ); // NOLINT(*) } -static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); @@ -331,6 +334,81 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } + +#if CUDA_VERSION >= 12080 +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // CUDA_VERSION + +void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if CUDA_VERSION >= 12080 + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back());); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // CUDA_VERSION >= 12080 +} + } // namespace dequantization namespace detail { @@ -339,17 +417,25 @@ void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { - if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + dequantization::fp8_dequantize(input, output, stream); + break; } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + dequantization::mxfp8_dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + dequantization::fp4_dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } } diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh new file mode 100644 index 000000000..fe9736298 --- /dev/null +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -0,0 +1,1515 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_transpose.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ +#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ + +#include +#include +#include + +#if CUDA_VERSION > 12080 +#include +#endif // CUDA_VERSION > 12080 + +#include + +#include "../common.h" +#include "../utils.cuh" +#include "curanddx.hpp" +#include "math.h" +#include "ptx.cuh" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { + +#if CUDA_VERSION > 12080 +namespace nvfp4_transpose { + +using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + + curanddx::SM<800>() + curanddx::Thread()); + +using namespace ptx; +using nvfp4_scale_t = fp8e4m3; + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; +constexpr size_t RNG_GENS_PER_THREAD = + SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. + using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +} + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + curanddx::uniform_bits dist; + random_uint4 = dist.generate4(rng); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + // NOTE: rbits unused for rn. + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); +#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} + +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + +template +__global__ void __launch_bounds__(THREADS_NUM) + nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, + const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, + rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, + const float *noop, const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + RNG rng(rng_seed, rng_sequence, rng_offset); + curanddx::uniform_bits dist; + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + // NEW: 2D Block-based scaling constants + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile + constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + const size_t warp_id = threadIdx.x / 32; + const size_t lane_id = threadIdx.x % 32; + float thread_amax = 0.0f; + const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + + // Helper function for warp reduction + auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { +#pragma unroll + for (int delta = 8; delta >= 1; delta /= 2) { + float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); + thread_amax = fmaxf(thread_amax, other_amax); + } + return thread_amax; + }; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + +#pragma unroll + for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + const size_t block_in_tile_y = block_iter; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + for (int elem = 0; elem < BLOCK_DIM; elem += 2) { + const size_t elem_0_row = block_iter * BLOCK_DIM + elem; + const size_t elem_1_row = elem_0_row + 1; + const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + const size_t elem_1_col = elem_0_col; + + const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; + const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; + + IType2 val_2x; + val_2x.x = in_sh[shmem_offset_0]; + val_2x.y = in_sh[shmem_offset_1]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); + } + + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else { + for (int elem = 0; elem < BLOCK_DIM; ++elem) { + const size_t elem_row = block_iter * BLOCK_DIM + elem; + const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + + // Bounds checking + const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows); + const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); + if (!row_out_of_bounds && !col_out_of_bounds) { + const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; + float elt = static_cast(in_sh[shmem_offset]); + + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset] = static_cast(elt); + } + + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + } + // Warp reduction to get block amax + block_amax = warp_reduce_amax(thread_amax, block_in_warp); + + if (lane_id == 0 || lane_id == 16) { + block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; + } + } + + // sync thread to ensure block_amax_matrix is done storing + __syncthreads(); + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 3. Scale elements + + // Load data in + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + } + } else { + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, + rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = tid_X_rowwise; + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + } + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace nvfp4_transpose +#endif // CUDA_VERSION > 12080 + +// Compile-time flag to choose kernel variant +#ifndef USE_2D_NVFP4_KERNEL +#define USE_2D_NVFP4_KERNEL 0 +#endif + +template +void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if CUDA_VERSION > 12080 + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. + // TODO(Frank): Is there a better way to do this? + bool return_transpose = output->has_columnwise_data(); + + using namespace nvfp4_transpose; + using namespace ptx; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + if (return_transpose) { + NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = output->columnwise_scale_inv.shape[1]; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = nvfp4_transpose_kernel; + + if constexpr (use_2d_quantization) { + kernel = nvfp4_transpose_kernel_2D; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION > 12080 +} +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 581de9f9f..85717afdf 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -14,6 +14,10 @@ #include #include +#if CUDA_VERSION >= 12080 +#include +#endif // CUDA_VERSION >= 12080 + namespace transformer_engine { namespace ptx { @@ -117,9 +121,13 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { return __int_as_float(biased_exp << FP32_MANTISSA_BITS); } +#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \ + ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM103_ALL))) + __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) +#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL + uint16_t out; asm volatile( "{\n" @@ -222,18 +230,86 @@ struct alignas(2 * sizeof(T)) FPx2 { T y; }; +template +struct FPx4 { + T x1; + T x2; + T x3; + T x4; +}; + +template +struct Type2x {}; + +template <> +struct Type2x { + using type = float2; +}; + +template <> +struct Type2x { + using type = __nv_bfloat162; +}; + +template <> +struct Type2x { + using type = __half2; +}; + using floatx2 = FPx2; using bf16x2 = FPx2; using fp16x2 = FPx2; using fp8e4m3x2 = FPx2; using fp8e5m2x2 = FPx2; +using floatx4 = FPx4; +using bf16x4 = FPx4; +using fp16x4 = FPx4; +using fp8e4m3x4 = FPx4; +using fp8e5m2x4 = FPx4; + static_assert(sizeof(floatx2) == 8); static_assert(sizeof(bf16x2) == 4); static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); +#if CUDA_VERSION >= 12080 +using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; +static_assert(sizeof(fp4e2m1x2) == 1); +static_assert(sizeof(fp4e2m1x4) == 2); +#endif // CUDA_VERSION >= 12080 + +// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1 + +// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6. + +// vt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on following architectures: +// sm_100a +// sm_101a +// sm_120a + +// When converting to .e2m1x2 data formats, the destination operand d has .b8 type. +// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, +// and the converted values are packed in the destination operand d such that the value +// converted from input a is stored in the upper 4 bits of d and the value converted +// from input b is stored in the lower 4 bits of d. + +// SIMD like "Fused" cast + multiplication (x4) +#if CUDA_VERSION >= 12080 +template +__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, + const float scale) { + const float x0 = static_cast(in01.x) * scale; + const float x1 = static_cast(in01.y) * scale; + const float x2 = static_cast(in23.x) * scale; + const float x3 = static_cast(in23.y) * scale; + out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); +} +#endif // CUDA_VERSION >= 12080 + // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { @@ -369,7 +445,7 @@ __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const "r"(reinterpret_cast(p2))); } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // namespace ptx diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index 68b7aa8bb..bce124e70 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -22,7 +22,8 @@ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 3f5bcc975..bc764ac74 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -35,6 +35,26 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// +// Device-side error +#define NVTE_DEVICE_ERROR(message) \ + do { \ + printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \ + __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \ + (message)); \ + assert(0); \ + } while (false) + +// Device-side error on thread 0 +#define NVTE_DEVICE_THREAD0_ERROR(message) \ + do { \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \ + threadIdx.y == 0 && threadIdx.z == 0) { \ + NVTE_DEVICE_ERROR(message); \ + } \ + } while (false) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) return {a.x + b.x, a.y + b.y}; } diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index d1470e22e..a1fae730c 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -89,3 +89,5 @@ dist_group_type = torch.distributed.ProcessGroup MXFP8_BLOCK_SCALING_SIZE = 32 + +NVFP4_BLOCK_SCALING_SIZE = 16 diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index e4f4e619f..d330e023e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -13,6 +13,8 @@ from ..tensor.quantized_tensor import Quantizer from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.utils import is_experimental +from ..experimental.gemm import experimental_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ @@ -77,6 +79,24 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + # If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation + if is_experimental(A) or is_experimental(B): + return experimental_gemm( + A, + B, + workspace, + out_dtype, + quantization_params, + gelu, + gelu_in, + accumulate, + layout, + out, + bias, + use_split_accumulator, + grad, + ) + debug_quantizer = None if isinstance(quantization_params, DebugQuantizer): debug_quantizer = quantization_params diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index dffb899f7..49ae963d7 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,6 +12,20 @@ namespace transformer_engine::pytorch { +/*! convert fp4 data shape back to original shape */ +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { + std::vector ret; + size_t start_idx = (transpose) ? 1 : 0; + for (size_t i = start_idx; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() * 2); + if (transpose) { + ret.push_back(shape.front()); + } + return ret; +} + std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { @@ -291,4 +305,20 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, + arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, + at::cuda::getCurrentCUDAStream()); + }); +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) { + at::PhiloxCudaState philox_args; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 2d35de852..c94bd0d2a 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -194,20 +195,25 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - /*! @brief Construct a high precision tensor giving it this quantizer's amax - - Note: this member function also zeros out the amax, as it is meant to be used in conjunction with - a kernel computing the amax, which might expect the amax to be initialized to zero + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. */ - std::pair create_hp_tensor_with_amax(const std::vector& shape, - DType dtype); + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype); std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - /*! @brief Convert to a quantized data format avoiding amax computation */ + /*! @brief Quantize to FP8, skipping local amax computation + * + * The quantizer's amax pointer is assumed to already hold the local + * amax. The amax may still be reduced across the amax reduction + * group. + */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt); @@ -277,6 +283,60 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; +class NVFP4Quantizer : public Quantizer { + public: + // fp4 dtype + DType dtype; + // amax reduction for low precision FP4 AG + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + // random hadamard transform + bool with_rht; + bool with_post_rht_amax; + // 2D block scaling + bool with_2d_quantization; + bool stochastic_rounding; + + int rht_matrix_random_sign_mask_t; + at::Tensor rht_matrix; + + explicit NVFP4Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. + */ + std::pair create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype); + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Quantize to NVFP4, skipping local amax computation + * + * The input tensor's amax pointer is assumed to already hold the + * local amax. The amax may still be reduced across the amax + * reduction group. + */ + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + + private: + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, bool compute_amax); +}; + std::unique_ptr convert_quantizer(py::handle quantizer); std::vector getTensorShape(const at::Tensor& t); @@ -420,6 +480,15 @@ std::vector convertShape(const NVTEShape& shape); size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); + +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); + +// unpack the PhiloxCudaState into CUDA tensor +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); + } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 7851cc5ff..cdfb4be40 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -8,179 +8,269 @@ #include "common.h" #include "pybind.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { +namespace { + +py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t), + const at::Tensor& input, py::handle quantizer, + int shape_divisor = 1) { init_extension(); // Input tensor auto input_tensor = input.contiguous(); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct output tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape = input_cpp.shape(); + const auto input_shape = input_nvte.shape(); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); output_shape.back() /= shape_divisor; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); - // Compute activation + // Choose implementation + enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation directly - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation in high-precision fused together with amax, then quantize. - - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); - } else { - // Compute activation in high-precision, then quantize - - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp->quantize(temp_cpp, out_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + quantizer_cpp->quantize(temp_nvte, out_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation directly + { + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer) { +py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, + cudaStream_t), + const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer) { init_extension(); // Grad output and input tensors auto grad_output_tensor = grad_output.contiguous(); auto input_tensor = input.contiguous(); - const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_cpp.shape(); + const auto input_shape_te = input_nvte.shape(); const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); - // Compute activation backward + // Choose implementation + enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation backward directly - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation backward in high-precision fused together with amax, then quantize. - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); - } else { - // Compute activation backward in high-precision, then quantize - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation backward in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_nvte, grad_input_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation backward directly + { + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return grad_input_py; } -/* GELU and variants*/ +} // namespace + +/* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_gelu, input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dgelu, grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_geglu, input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dgeglu, grad, input, quantizer); } py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_qgelu, input, quantizer); } py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dqgelu, grad, input, quantizer); } py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_qgeglu, input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dqgeglu, grad, input, quantizer); } -/* ReLU and variants*/ +/* ReLU and variants */ py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_relu, input, quantizer); } py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_drelu, grad, input, quantizer); } py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_reglu, input, quantizer, 2); } py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dreglu, grad, input, quantizer); } py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_srelu, input, quantizer); } py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsrelu, grad, input, quantizer); } py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_sreglu, input, quantizer, 2); } py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsreglu, grad, input, quantizer); } -/* Silu and variants*/ +/* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_forward(nvte_silu, input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dsilu, grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_forward(nvte_swiglu, input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return activation_backward(nvte_dswiglu, grad, input, quantizer); } -} // namespace transformer_engine::pytorch + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 8179727e5..5db9dd73d 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); } -void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, - arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, - at::cuda::getCurrentCUDAStream()); - }); -} - -// extract PhiloxCudaState from CUDA random number generator -at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) { - at::PhiloxCudaState philox_args; - std::lock_guard lock(gen->mutex_); - philox_args = gen->philox_cuda_state(elts_per_thread); - return philox_args; -} - } // namespace namespace transformer_engine::pytorch { @@ -198,7 +182,7 @@ std::vector fused_attn_fwd( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack(philox_args, static_cast(rng_state.data_ptr())); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index a80cb35f2..0531596dd 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -122,13 +122,27 @@ std::vector dact_dbias( } // Choose implementation - enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; + enum class Impl { + UNFUSED, + FUSED_DACT_DBIAS_QUANTIZE, + FUSED_DACT_AMAX_FP8, + FUSED_DACT_AMAX_NVFP4 + }; Impl impl = Impl::UNFUSED; if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || detail::IsMXFP8Quantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - impl = Impl::FUSED_DACT_AMAX; + impl = Impl::FUSED_DACT_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_DACT_AMAX_NVFP4; + } } // Perform compute @@ -172,20 +186,38 @@ std::vector dact_dbias( }); break; } - case Impl::FUSED_DACT_AMAX: - // Fused dact-amax kernel, unfused dbias and quantize + case Impl::FUSED_DACT_AMAX_FP8: + // Fused dact-amax kernel, unfused dbias and FP8 quantize { - auto *quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(quantizer_cpp_cs != nullptr, + auto *fp8_quantizer_cpp = + dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl"); auto [temp_nvte, temp_py] = - quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + break; + } + case Impl::FUSED_DACT_AMAX_NVFP4: + // Fused dact-amax kernel, unfused dbias and NVFP4 quantize + { + auto *nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, + "Invalid quantizer for fused dact-amax kernel impl"); + auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax( + grad_input_nvte, grad_output_dtype); NVTE_SCOPED_GIL_RELEASE({ dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); }); const auto temp_torch = temp_py.cast(); at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } default: diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 0d18a5ec5..136459751 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -213,6 +213,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + if (grad) { + config.set_dbias_tensor(bias_tensor.data()); + config.set_with_dgelu_epilogue(gelu); + } else { + config.set_bias_tensor(bias_tensor.data()); + config.set_with_gelu_epilogue(gelu); + } + config.set_epilogue_aux_tensor(te_pre_gelu_out.data()); + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sms); + // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; auto main_stream = at::cuda::getCurrentCUDAStream(); @@ -276,10 +289,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), - bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, - te_workspace.data(), alpha, *beta, use_split_accumulator, - num_math_sms, main_stream); + nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(), + out_tensor.data(), out_tensor.data(), te_workspace.data(), config, + main_stream); }); } } else { diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index c63f892ce..3fa0fb0aa 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -66,67 +66,102 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); - TensorWrapper bias_cu; + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); + TensorWrapper bias_nvte; if (bias.has_value()) { - bias_cu = makeTransformerEngineTensor(*bias); + bias_nvte = makeTransformerEngineTensor(*bias); } // Tensor dimensions - const size_t N = static_cast(input_cu.size(0)); - const size_t H = static_cast(input_cu.size(1)); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - TensorWrapper mu_cu = makeTransformerEngineTensor(mu); - TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; + } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; } } - TensorWrapper unquantized_out_cu; + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -138,24 +173,31 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + default: { } } - return {out, py::cast(mu), py::cast(rsigma)}; + return {out, py::cast(mu_py), py::cast(rsigma_py)}; } std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -254,61 +296,95 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); // Tensor dimensions - const size_t N = static_cast(input_cu.shape().data[0]); - const size_t H = static_cast(input_cu.shape().data[1]); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; + } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; } } - TensorWrapper unquantized_out_cu; + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -320,24 +396,30 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; + default: { } } - return {out, py::none(), py::cast(rsigma)}; + return {out, py::none(), py::cast(rsigma_py)}; } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 7649ccb6d..98f71f9a7 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -32,6 +32,9 @@ PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; +PyTypeObject *NVFP4TensorPythonClass = nullptr; +PyTypeObject *NVFP4TensorBasePythonClass = nullptr; +PyTypeObject *NVFP4QuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -86,10 +89,26 @@ void init_float8blockwise_extension() { "Internal error: could not initialize pyTorch float8blockwise extension."); } +void init_nvfp4_extensions() { + if (NVFP4TensorPythonClass) return; + auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); + NVFP4QuantizerClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); + NVFP4TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); + auto nvfp4_base_module = + py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); + NVFP4TensorBasePythonClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase")); + NVTE_CHECK(NVFP4TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch NVFP4 extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); + init_nvfp4_extensions(); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9fd1ae4de..f46edaa70 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -40,13 +40,12 @@ extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; +extern PyTypeObject *NVFP4TensorPythonClass; +extern PyTypeObject *NVFP4TensorBasePythonClass; +extern PyTypeObject *NVFP4QuantizerClass; void init_extension(); -void init_float8_extension(); - -void init_mxfp8_extension(); - namespace detail { inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } @@ -69,11 +68,17 @@ inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } +inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; } + inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; } +inline bool IsNVFP4Tensor(PyObject *obj) { + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass; +} + TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); template @@ -88,6 +93,8 @@ std::unique_ptr CreateMXFP8Params(const py::handle params); TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantization_params); +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -100,8 +107,9 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, - NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; - + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), + std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, + CreateQuantizer)}; } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index cd7e70fec..2abe9614e 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -31,8 +31,20 @@ std::vector make_transpose_shape(const std::vector& shape) { return ret; } +/*! @brief Convert shape for FP4 data by dividing the last dimension by 2 */ +template +std::vector convert_shape_for_fp4(const std::vector& shape) { + std::vector ret; + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() / 2); + return ret; +} + } // namespace +constexpr size_t NVFP4_BLOCK_SIZE = 16; constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -376,8 +388,9 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8CurrentScalingQuantizer::create_hp_tensor_with_amax( - const std::vector& shape, DType dtype) { +std::pair +Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, + DType dtype) { amax.zero_(); auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), @@ -899,7 +912,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1095,7 +1108,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); std::vector scale_shape; @@ -1116,4 +1129,573 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->with_rht = quantizer.attr("with_rht").cast(); + this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); + this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); + this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + + // Get amax reduction group if needed for NVFP4 AG + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + c10::intrusive_ptr amax_reduction_group; + if (with_amax_reduction) { + auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); + NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group"); + amax_reduction_group = group.cast>(); + } + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + + this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast(); + this->rht_matrix = quantizer.attr("rht_matrix").cast(); +} + +void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { + // set dtype for rowwise and columnwise data in tensor wrapper + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(this->dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(this->dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { + using namespace pybind11::literals; + + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; + const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_rowwise = at::empty({1}, bit32_tensor_opts); + } + if (columnwise_usage) { + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data_tensor = + at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + amax_columnwise = at::empty({1}, bit32_tensor_opts); + } + + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage); + auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage); + + // Construct Python NVFP4 tensor + py::object out_py; + if (internal) { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorBasePythonClass)); + out_py = NVFP4TensorClass( + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } else { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); + out_py = NVFP4TensorClass( + "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + rowwise_scale_inv_shape); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + columnwise_scale_inv_shape); + out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype) { + // Construct tensor + auto shape = convertShape(quantized_tensor.shape()); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + + // Register amax pointer from quantized tensor + void* amax_ptr = quantized_tensor.amax(); + if (amax_ptr == nullptr) { + amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + } + NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); + out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + + // Zero out amax + NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + auto amax_rowwise = get_tensor("_amax_rowwise"); + auto amax_columnwise = get_tensor("_amax_columnwise"); + NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); + + // Tensor dimensions, shape means original shape + std::vector shape; + if (columnwise_data) { + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + if (rowwise_data) { + auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked columnwise_data_tensor == true + shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + } + + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + if (!amax_rowwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + amax_rowwise = at::empty({1}, opts); + tensor.attr("_amax_rowwise") = *amax_rowwise; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + if (amax_rowwise) { + amax_rowwise.reset(); + tensor.attr("_amax_rowwise") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + if (!amax_columnwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + amax_columnwise = at::zeros({1}, opts); + tensor.attr("_amax_columnwise") = *amax_columnwise; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + if (amax_columnwise) { + amax_columnwise.reset(); + tensor.attr("_amax_columnwise") = py::none(); + } + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, + bool compute_amax) { + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); + quant_config.set_stochastic_rounding(this->stochastic_rounding); + + // We only need RHT for columnwise usage. + // flat first dim and last dim for multi dimensional input + size_t rows = 1; + for (size_t i = 0; i < input.ndim() - 1; ++i) { + rows *= input.size(i); + } + size_t cols = input.size(input.ndim() - 1); + + TensorWrapper te_rng_state; + if (this->stochastic_rounding) { + const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, opts); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); + te_rng_state = makeTransformerEngineTensor(rng_state); + quant_config.set_rng_state(te_rng_state.data()); + } + + // Restriction for the RHT cast fusion kernel. + bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + + // Compute amax. + if (this->with_rht) { + if (input.dtype() != DType::kBFloat16) { + NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + } + if (this->with_post_rht_amax) { + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for RHT(input.t) + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_amax(input.data(), out.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } else { + // raise error since it's not supported yet + NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + } + } else { // Without RHT + if (compute_amax) { + // Amax pointers + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + // Compute amax of input tensor + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Make sure row-wise and column-wise amaxes match + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } + + // amax reduction + if (this->with_amax_reduction) { + std::vector amax_tensors; + // push amax tensors inside if they need to be reduced + auto make_amax_tensor = [](void* data_ptr) { + return at::from_blob( + data_ptr, std::vector{1}, + [](void*) {}, // deleter doing nothing since it doesn't own the data + at::device(at::kCUDA).dtype(torch::kFloat32)); + }; + if (rowwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr)); + } + if (columnwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr)); + } + c10d::AllreduceCoalescedOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + NVTE_SCOPED_GIL_RELEASE( + { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); + } + + if (this->with_rht) { + if (rowwise_usage) { + // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise + TensorWrapper out_identity(out.scaling_mode()); + auto out_identity_data = out.get_rowwise_data(); + auto out_identity_scale_inv = out.get_rowwise_scale_inv(); + auto out_identity_amax = out.get_amax(); + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); + } + + if (columnwise_usage) { + // Get the output columnwise data, scale_inv, and amax + auto out_columnwise_data = out.get_columnwise_data(); + auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); + // NOTE: should already be populated. + auto out_columnwise_amax = out.get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. + // The reason is due to the input `rht_output_t` is already in the transposed layout. + // Thus, we only need a rowwise quantization to generate the columnwise output. + TensorWrapper out_transpose(out.scaling_mode()); + // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail + // need to convert the shape to 2D here + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte + // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again + // so the multiple 2 get cancelled out + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + + if (!eligible_for_rht_cast_fusion) { + // Invoking fallback RHT kernel. + + // If using RHT, then amax will be computed in the RHT step + // If not using RHT, then amax will be computed based on input x + at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout + // This wrapper is going to be passed as input to the quantization kernel. + TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + // NOTE (frsun): This is non-intuitive, we are writing the + // result of transposed RHT to the output of rowwise. + rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + + NVTE_SCOPED_GIL_RELEASE({ + // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + + // Quantize kernel will treat everything as rowwise input/output, which is + // intended. + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); + }); + } else { + // RHT cast fusion kernel. + NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, + "RHT matrix is not set"); + auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_cast_fusion_columnwise( + input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); + }); + } + } + } else { + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); + } +} + +void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true); +} + +void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + // Update output tensor amaxes with input tensor amax + auto input_amax_ptr = input.amax(); + auto output_rowwise_amax_ptr = out.get_amax().data_ptr; + auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + NVTE_CHECK(input_amax_ptr != nullptr || + (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr), + "Input tensor does not have pre-computed amax"); + if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr && + output_rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + if (input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr && + output_columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + + // Perform quantization + this->quantize_impl(input, out, std::nullopt, false); +} + +std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + auto last_dim = shape.back(); + auto flat_first_dim = numel / last_dim; + + NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = roundup(flat_first_dim, 128); + size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = roundup(last_dim, 128); + size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index cb2121a45..368e9dcdf 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer return ret; } +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = tensor.attr("_fp4_dtype").cast(); + + auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); + + // Row-scaled data + if (rowwise_usage) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); + ret.set_rowwise_data(data.data_ptr(), dtype, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); + ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); + } + + // Column-scaled data + if (columnwise_usage) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); + ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, + getTensorShape(scale_inv)); + ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + getTensorShape(amax_columnwise)); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 92f2d3a50..3bb6be715 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -14,22 +14,31 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } - NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8, + "4-bit or 8-bit input required for swizzling scaling factors."); + + const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING; NVTEBasicTensor scale_inv; + NVTEShape nvte_input_shape; if (rowwise) { + nvte_input_shape = input.shape(); scale_inv = input.get_rowwise_scale_inv(); } else { + nvte_input_shape = input.get_columnwise_data().shape; scale_inv = input.get_columnwise_scale_inv(); } - auto input_shape = nvte_shape_to_vector(input.shape()); + auto input_shape = nvte_shape_to_vector(nvte_input_shape); auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape."); + // Allocate memory for swizzled output. auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); std::vector scale_inv_shape_int; @@ -41,36 +50,34 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); // Reconstruct input only to avoid swizzling both directions if not needed. - // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + // The specific dtype used is irrelevant, just needs to be correct bits. + transformer_engine::TensorWrapper input_cu(input.scaling_mode()); + transformer_engine::TensorWrapper output_cu(input.scaling_mode()); + + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } return swizzled_scale_inv; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 217cb98c7..3ab0717d0 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -39,11 +39,14 @@ from .fp8 import FP8GlobalStateManager, fp8_autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensor, Quantizer +from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer from .tensor._internal.float8_tensor_base import Float8TensorBase from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .triton.pad import pad_columnwise_scale_inv from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer @@ -1204,6 +1207,245 @@ def _all_gather_fp8_blockwise( return out, handle +def _swap_first_dims(tensor: torch.Tensor, world_size: int): + """ + Swap first 2 dimensions of a tensor to fix interleaved + data format after gathering transposed data. + + For more than 2 dimensions, we squash the trailing dimensions, + instead of the first few dimensions, that's because the shape + passed in this function is already transposed. + """ + + shape = tensor.shape + assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave." + first_dim = shape[0] + flattened_trailing = math.prod(shape[1:]) + assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." + tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing) + tensor = tex.swap_first_dims(tensor, out=None) + return tensor.reshape(first_dim // world_size, flattened_trailing * world_size) + + +def _post_process_nvfp4_gather( + out: NVFP4TensorBase, + columnwise_data_interleaved: torch.Tensor, + columnwise_scale_inv_interleaved: torch.Tensor, + world_size: int, + handle: Optional[torch.distributed.Work] = None, +) -> NVFP4TensorBase: + """Post-process FP8 blockwise gather.""" + if handle is not None: + handle.wait() + handle = None + + # Fix the interleaved transposed data from gathering along first dim. + out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size) + out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size) + + # Optionally pad the scaling inverse if needed. + out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv) + + +@dataclass +class _NVFP4AllGatherAsyncHandle: + """Handle for asynchronous NVFP4 all-gather.""" + + output: NVFP4TensorBase + columnwise_data_interleaved: torch.Tensor + columnwise_scale_inv_interleaved: torch.Tensor + world_size: int + async_handle: torch.distributed.Work + _synchronized: bool = False + + def wait(self) -> None: + """Wait for the async operation to complete and post-process the tensor.""" + if self._synchronized: + return + self.async_handle.wait() + _post_process_nvfp4_gather( + self.output, + self.columnwise_data_interleaved, + self.columnwise_scale_inv_interleaved, + self.world_size, + ) + self._synchronized = True + + +def _all_gather_nvfp4( + inp: torch.Tensor, + process_group: dist_group_type, + *, + async_op: bool = False, + quantizer: NVFP4Quantizer, + out_shape: Optional[list[int]] = None, +) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]: + """All-gather NVFP4 tensor along first dimension.""" + + # Input tensor attributes + in_shape: Iterable[int] = None + in_shape_t: Iterable[int] = None + device: torch.device + dtype: torch.dtype + + # Construct packed shapes for input and input_t. + if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase): + # High-precision tensor. + in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size()) + in_shape_t = NVFP4Quantizer.convert_shape_for_fp4( + NVFP4Quantizer.get_columnwise_shape(inp.size()) + ) + device = inp.device + dtype = inp.dtype + elif isinstance(inp, NVFP4TensorBase): + if inp._rowwise_data is not None: + in_shape = inp._rowwise_data.size() + device = inp._rowwise_data.device + if inp._columnwise_data is not None: + in_shape_t = inp._columnwise_data.size() + device = inp._columnwise_data.device + dtype = torch.bfloat16 + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, " + f"found {inp.__class__.__name__})" + ) + + assert in_shape is not None or in_shape_t is not None, "No data found." + + world_size = get_distributed_world_size(process_group) + + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # For cases where inp has dimensions that cannot be quantized, + # we gather in high precision followed by a cast to NVFP4. + if ( + not isinstance(inp, NVFP4TensorBase) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + # Cast input tensor to NVFP4 with required data + if not isinstance(inp, NVFP4TensorBase): + inp = quantizer(inp) + elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( + quantizer.columnwise_usage and inp._columnwise_data is None + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to NVFP4." + ) + inp = quantizer(inp.dequantize()) + + # Construct NVFP4 output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Coalesce NCCL collectives for gathering data and scale inverses. + with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as gather_coalescing_manager: + + # Gather NVFP4 data for row-wise usage + if quantizer.rowwise_usage: + + # Remove padding from NVFP4 scale-inverses + assert in_shape is not None, "Shape not found." + in_scale_inv = inp._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_rowwise = inp._amax_rowwise + + # Gather the transposed NVFP4 data along first dimension. Fix format later. + if quantizer.columnwise_usage: + + # Remove padding from NVFP4 scale-inverses + # For doing an all-gather on transposed scale inverses, + # we need to remove padding from both dimension. + in_scale_inv = inp._columnwise_scale_inv + # take caution that for in_shape_t, flatten in the trailing dimensions! + flattened_in_shape0 = in_shape_t[0] + flattened_in_shape1 = math.prod(in_shape_t[1:]) + + # Remove dim0 padding + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + + # Remove dim1 padding (pack first). + unpadded_dim1 = flattened_in_shape1 * 2 // 16 + if in_scale_inv.size(1) != unpadded_dim1: + in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous() + + # Construct tensor to gather transposed scale_inv (interleaved) and launch AG. + out_scale_inv = torch.empty( + [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]], + dtype=in_scale_inv.dtype, + layout=in_scale_inv.layout, + device=in_scale_inv.device, + ) + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + + # Construct tensor to gather transposed data (interleaved) and launch AG. + out_columnwise_data = torch.empty( + [inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]), + dtype=inp._columnwise_data.dtype, + layout=inp._columnwise_data.layout, + device=inp._columnwise_data.device, + ) + torch.distributed.all_gather_into_tensor( + out_columnwise_data, + inp._columnwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_columnwise = inp._amax_columnwise + + handle = gather_coalescing_manager if async_op else None + + # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. + if async_op and quantizer.columnwise_usage: + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, handle + ) + elif quantizer.columnwise_usage: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + + return out, handle + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1291,7 +1533,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1315,7 +1556,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1347,7 +1587,7 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(inp, QuantizedTensor): + if quantizer is not None and not isinstance(inp, QuantizedTensorBase): inp = quantizer(inp) return inp, None @@ -1426,13 +1666,24 @@ def gather_along_first_dim( out_shape=out_shape, ) + # NVFP4 case + if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer): + assert isinstance(quantizer, NVFP4Quantizer) + return _all_gather_nvfp4( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorBase): inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format @@ -1450,7 +1701,7 @@ def gather_along_first_dim( return out, None # Dequantize quantized tensor if not supported - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorBase): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." diff --git a/transformer_engine/pytorch/experimental/__init__.py b/transformer_engine/pytorch/experimental/__init__.py new file mode 100644 index 000000000..11658f636 --- /dev/null +++ b/transformer_engine/pytorch/experimental/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Experimental features and APIs.""" + +from .config import set_qlinear_params, get_experimental_quantizers + + +__all__ = ["set_qlinear_params", "get_experimental_quantizers"] diff --git a/transformer_engine/pytorch/experimental/config.py b/transformer_engine/pytorch/experimental/config.py new file mode 100644 index 000000000..fec6bc938 --- /dev/null +++ b/transformer_engine/pytorch/experimental/config.py @@ -0,0 +1,201 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Config API for experimental middleware between Transformer Engine and Kitchen.""" + +import dataclasses +import enum +import os +from typing import Optional + +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.experimental import quantization +from transformer_engine.pytorch.experimental import quantization_microblock_ref +from transformer_engine.pytorch.experimental.quantization import MMParams + + +@dataclasses.dataclass() +class QLinearParams: + """Quantization parameters of linear layer. + + Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors. + """ + + x_quantizer: Optional[quantization.ExperimentalQuantizer] = None + w_quantizer: Optional[quantization.ExperimentalQuantizer] = None + g_quantizer: Optional[quantization.ExperimentalQuantizer] = None + + mm_fprop: Optional[MMParams] = None + mm_dgrad: Optional[MMParams] = None + mm_wgrad: Optional[MMParams] = None + + +@enum.unique +class QuantizeRecipe(enum.Enum): + """Pre-defined quantization recipes for linear layers.""" + + NON_QUANTIZE = "non_quantize" + NVFP4_REF = "nvfp4_ref" + NVFP4_REF_RHT_ONLY = "nvfp4_ref_rht_only" + NVFP4_REF_2D_QUANTIZATION_ONLY = "nvfp4_ref_2d_quantization_only" + NVFP4_REF_RHT_AND_2D_QUANTIZATION = "nvfp4_ref_rht_and_2d_quantization" + + +def get_qlinear_params_from_predefined( + recipe: QuantizeRecipe, +) -> Optional[QLinearParams]: + """Get quantization parameters for linear layer based on recipe.""" + if recipe == QuantizeRecipe.NON_QUANTIZE: + return None + if recipe == QuantizeRecipe.NVFP4_REF: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=False, + ), + ) + if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION: + return QLinearParams( + x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ), + g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ), + ) + raise ValueError(f"Unsupported quantize recipe: {recipe}") + + +def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]: + """Load quantization options from Kitchen to Transformer Engine. + + TODO(etsykunov): Confirm docstring is correct. + """ + assert qat_params_idx > 0, "QAT_PARAMS is not set." + + if qat_params_idx == 6010: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF) + if qat_params_idx == 960109: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY) + if qat_params_idx == 9002: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY) + if qat_params_idx == 9003: + return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION) + raise ValueError(f"Unsupported QAT params index: {qat_params_idx}") + + +def set_qlinear_params( + qlinear_params: Optional[QLinearParams] = None, + layer_number: Optional[int] = None, + layer_name: Optional[str] = None, +) -> Optional[QLinearParams]: + """Set quantization parameters based on configuration. + + Args: + qlinear_params: Quantization parameters. If None, loaded from environment. + layer_number: The numerical index of this layer in the model structure. + layer_name: The name for this layer. + + Returns: + QLinearParams: The finalized quantization parameters for this layer. + """ + if qlinear_params is None: + qat_params_idx = int(os.getenv("QAT_PARAMS", "0")) + if qat_params_idx == 0: + return None + return get_qlinear_params_from_qat_params(qat_params_idx) + + # Apply layer-specific overrides + if layer_number is not None: + raise NotImplementedError("Layer-specific overrides are not supported yet.") + if layer_name is not None: + raise NotImplementedError("Layer-specific overrides are not supported yet.") + + return qlinear_params + + +def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams): + """Replacement of _get_quantizers() in TE modules.""" + if not fp8: + raise ValueError("FP8 is required to be enabled for experimental quantization.") + input_quantizer = qlinear_params.x_quantizer + weight_quantizer = qlinear_params.w_quantizer + output_quantizer = None + grad_input_quantizer = None + grad_weight_quantizer = None + grad_output_quantizer = qlinear_params.g_quantizer + + return ( + input_quantizer, + weight_quantizer, + output_quantizer, + grad_input_quantizer, + grad_weight_quantizer, + grad_output_quantizer, + ) diff --git a/transformer_engine/pytorch/experimental/gemm.py b/transformer_engine/pytorch/experimental/gemm.py new file mode 100644 index 000000000..d743b577b --- /dev/null +++ b/transformer_engine/pytorch/experimental/gemm.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""GEMM API for experimental middleware between Transformer Engine and Kitchen.""" + +from typing import Iterable, Optional + +import torch + +from transformer_engine.pytorch.experimental.quantization import ( + MMParams, + GEMMType, + ExperimentalQuantizedTensor, +) +from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer + + +def experimental_gemm( + A: ExperimentalQuantizedTensor, + B: ExperimentalQuantizedTensor, + workspace: torch.Tensor, # pylint: disable=unused-argument + out_dtype: Optional[torch.dtype] = None, + quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument + gelu: bool = False, # pylint: disable=unused-argument + gelu_in: torch.Tensor = None, # pylint: disable=unused-argument + accumulate: bool = False, # pylint: disable=unused-argument + layout: str = "TN", + out: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + bias: Optional[torch.Tensor] = None, + use_split_accumulator: bool = False, + grad: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """Dispatch GEMM to quantizer's qgemm method.""" + assert isinstance(A, ExperimentalQuantizedTensor) and isinstance( + B, ExperimentalQuantizedTensor + ), "A and B must be ExperimentalQuantizedTensor instances" + + A, B = B, A + + # Determine GEMM type based on grad flag and layout + if not grad: + gemm_type = GEMMType.FPROP + else: + if layout == "NN": + gemm_type = GEMMType.DGRAD + elif layout == "NT": + gemm_type = GEMMType.WGRAD + else: + # Default to FPROP for other layouts + gemm_type = GEMMType.FPROP + + # Extract quantizer from QuantizedTensor to get qgemm logic + # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B.quantizer? + quantizer = None + if hasattr(A, "quantizer") and A.quantizer is not None: + quantizer = A.quantizer + elif hasattr(B, "quantizer") and B.quantizer is not None: + quantizer = B.quantizer + else: + raise ValueError("No quantizer found in QuantizedETensor objects") + + # Create MMParams + m_params = MMParams( + out_dtype=out_dtype, + use_split_accumulator=use_split_accumulator, + ) + out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype + + if gemm_type == GEMMType.FPROP: + qx, sx = A.data, A.scale + qw, sw = B.data, B.scale + assert qx is not None + assert sx is not None + assert qw is not None + assert sw is not None + assert A.original_shape is not None + + # Call quantizer's qgemm method + result = quantizer.qgemm( + qx, + qw, + m_params, + out_dtype, + sx, + sw, + bias, + gemm_type=GEMMType.FPROP, + qresult_x=A, + qresult_w=B, + ) + if len(A.original_shape) > 2: + # Original input was 3D, so we need to reshape result back to 3D + batch_size = A.original_shape[0] + seq_len = A.original_shape[1] + result = result.view(batch_size, seq_len, result.shape[-1]) + elif gemm_type == GEMMType.DGRAD: + qdy, sdy = A.data, A.scale + qw_t, sw_t = B.data_t, B.scale_t + assert qdy is not None + assert sdy is not None + assert qw_t is not None + assert sw_t is not None + + result = quantizer.qgemm( + qdy, + qw_t, + m_params, + out_dtype, + sdy, + sw_t, + None, + gemm_type=GEMMType.DGRAD, + qresult_x=A, + qresult_w=B, + ) + elif gemm_type == GEMMType.WGRAD: + qdy_t, sdy_t = A.data_t, A.scale_t + qx_t, sx_t = B.data_t, B.scale_t + assert qdy_t is not None + assert sdy_t is not None + assert qx_t is not None + assert sx_t is not None + + result = quantizer.qgemm( + qdy_t, + qx_t, + m_params, + out_dtype, + sdy_t, + sx_t, + None, + gemm_type=GEMMType.WGRAD, + qresult_x=A, + qresult_w=B, + ) + + # Return in the same format as general_gemm + return result, None, None, None diff --git a/transformer_engine/pytorch/experimental/quantization.py b/transformer_engine/pytorch/experimental/quantization.py new file mode 100644 index 000000000..9adf4dabf --- /dev/null +++ b/transformer_engine/pytorch/experimental/quantization.py @@ -0,0 +1,203 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization API for experimental middleware between Transformer Engine and Kitchen.""" + +from __future__ import annotations +import abc +import dataclasses +import enum +from typing import Iterable, Optional, Tuple, Union + +import torch + +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer +from transformer_engine.pytorch.experimental import utils + + +@enum.unique +class GEMMType(enum.Enum): + """Type of GEMM operation being performed.""" + + FPROP = "fprop" + DGRAD = "dgrad" + WGRAD = "wgrad" + + +@dataclasses.dataclass(frozen=True) +class MMParams: + """Matrix multiplication parameters.""" + + out_dtype: torch.dtype | None = None + # Use split accumulator for more accurate FP8 GEMM + use_split_accumulator: bool = True + + +@dataclasses.dataclass +class ExperimentalQuantizedTensor(QuantizedTensorBase): + """Base class for experimental quantized tensor containers. + + An experimental container to hold quantization result, including quantized tensor, optional + transposed quantized tensor, and corresponding decoding scales. + + data: torch.Tensor + the quantized tensor. + scale: torch.Tensor + the decoding scale for the quantized tensor. Shape depends on the scaling granularity. + - if scaling type is PER_TENSOR, it should be a 1D scalar tensor. + data_t: torch.Tensor + the transposed quantized tensor (computed lazily if needed). + scale_t: torch.Tensor + the decoding scale for the transposed quantized tensor. + dtype: torch.dtype + nominal tensor datatype. + device: torch.device + device of the tensor. + quant_dtype: Union[utils.Fp4Formats, torch.dtype] + low precision tensor datatype. + original_shape: Tuple[int, ...] + original shape of the tensor. + quantizer: ExperimentalQuantizer + Builder class for quantized tensor. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + data_t: Optional[torch.Tensor] = None + scale_t: Optional[torch.Tensor] = None + global_amax_row: Optional[torch.Tensor] = None + global_amax_col: Optional[torch.Tensor] = None + + dtype: Optional[torch.dtype] = None + device: Optional[torch.device] = None + quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None + original_shape: Optional[Tuple[int, ...]] = None + quantizer: Optional[ExperimentalQuantizer] = None + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware.""" + return True + + def get_quantizer(self) -> ExperimentalQuantizer: + """Get builder for QuantizedExperimentalTensor + + Quantizer can be used for in-place operations. + + """ + if self.quantizer is not None: + return self.quantizer + raise ValueError("Quantizer is not set") + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]: + """Prepare the quantization result for saving for backward""" + tensors = [self.data, self.data_t, self.scale, self.scale_t] + self.data = None + self.data_t = None + self.scale = None + self.scale_t = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the quantization result from the saved tensors""" + self.data = tensors[0] + self.data_t = tensors[1] + self.scale = tensors[2] + self.scale_t = tensors[3] + return tensors[4:] + + def dequantize(self, *args, **kwargs) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + # Compatibility + @property + def _data(self): + return self.data + + @_data.setter + def _data(self, value): + self.data = value + + @property + def _scale_inv(self): + return self.scale + + @_scale_inv.setter + def _scale_inv(self, value): + self.scale = value + + +class ExperimentalQuantizer(Quantizer): + """Experimental Quantizer class + + Defines the interface for experimental quantizers. + """ + + def __init__(self, *, rowwise: bool, columnwise: bool) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.internal = True + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware""" + return True + + @abc.abstractmethod + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + m_params: MMParams, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + gemm_type: GEMMType = GEMMType.FPROP, + qresult_x: ExperimentalQuantizedTensor | None = None, + qresult_w: ExperimentalQuantizedTensor | None = None, + ) -> torch.Tensor: + """Quantized GEMM interface.""" + + def dequantize(self, *args, **kwargs) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement dequantize function" + ) + + def update_quantized(self, *args, **kwargs) -> torch.Tensor: + """Update the quantized tensor with the given tensor in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_quantized function" + ) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + ) -> QuantizedTensorBase: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function" + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement calibrate function" + ) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_compatible_recipe function" + ) diff --git a/transformer_engine/pytorch/experimental/quantization_microblock_ref.py b/transformer_engine/pytorch/experimental/quantization_microblock_ref.py new file mode 100644 index 000000000..da749d237 --- /dev/null +++ b/transformer_engine/pytorch/experimental/quantization_microblock_ref.py @@ -0,0 +1,811 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen.""" + +from typing import Optional, Tuple + +import torch + +from transformer_engine.pytorch.experimental import quantization +from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.experimental.quantization import ( + ExperimentalQuantizedTensor, + ExperimentalQuantizer, +) + + +def cast_to_fp4x2(x): + """Quantize a tensor to FP4 E2M1 and store in a byte tensor""" + + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def cast_from_fp4x2(x, dq_dtype): + """Dequantize FP4 E2M1 tensor that has been represented in a byte tensor""" + fp4_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + device=x.device, + dtype=dq_dtype, + ) + + # Convert to long integers for indexing + second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long) + first_bit = (x - second_bit * 16).to(torch.long) + + # Use the long integers to index fp4_values + first_bit_values = fp4_values[first_bit] + second_bit_values = fp4_values[second_bit] + + result = torch.zeros( + (first_bit_values.shape[0], first_bit_values.shape[1] * 2), + device=x.device, + dtype=dq_dtype, + ) + result[:, ::2] = first_bit_values + result[:, 1::2] = second_bit_values + + return result + + +def cast_to_e8(decode_scale): + """Cast to a value that is representable in FP8 E8M0. + + The result is in FP32, not FP8 E8M0. + """ + max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32) + exponent = torch.ceil(torch.log2(decode_scale)) + exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent) + + return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent + + +def cast_to_e4m3(decode_scale, global_amax): + """Scale and cast to FP8 E4M3. + + decode_scale is actually the encoding scaling factor. global_amax + can be any data tensor and not just the amax. + + TODO(etsykunov): Make less unintuitive. + """ + decode_scale = decode_scale * global_amax + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + return decode_scale.to(torch.float8_e4m3fn) + + +def high_precision_gemm_ref( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + accumulate: bool = False, + is_a_transposed: bool = False, + is_b_transposed: bool = False, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_alpha: float = 1.0, +) -> torch.Tensor: + """GEMM implementation with unquantized data""" + # Handle transpositions + mat1, mat2 = a, b + if is_a_transposed: + mat1 = a.T + if is_b_transposed: + mat2 = b.T + + # Ensure dtype compatibility for torch.addmm + mat1 = mat1.to(out_dtype) + mat2 = mat2.to(out_dtype) + + # Determine output shape + y_shape = (mat1.size(0), mat2.size(1)) + + if bias is not None: + assert not accumulate, "Bias is not supported with accumulation" + bias = bias.to(out_dtype) + # With bias case + if out_dtype == torch.float32: + y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1) + else: + y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha) + else: + # Without bias case + if accumulate and out is not None: + y_ref = out.clone().to(out_dtype) + else: + y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device) + torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref) + + return y_ref + + +class NVFP4TensorRef(ExperimentalQuantizedTensor): + """NVFP4 tensor for middleware between Transformer Engine and Kitchen""" + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"dtype={self.dtype}, " + f"device={self.device}, " + f"quant_dtype={self.quant_dtype}, " + f"data={self.dequantize(dtype=self.dtype)}, " + f"original_shape={self.original_shape}" + ")" + ) + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> ExperimentalQuantizedTensor: + """In-place update of quantized data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, ExperimentalQuantizedTensor): + return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) + self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from quantized tensor + """ + if dtype is None: + dtype = self.dtype + + # Ignore data_t for now + assert self.data is not None, "QuantizedTensor has no valid tensor data" + assert self.scale is not None, "QuantizedTensor has no valid scale" + tensor_data = self.data + tensor_scale = self.scale + # Dispatch to the quantizer + return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """Generate or remove quantized data based on provided usage.""" + has_data = self.data is not None + has_data_transpose = self.data_t is not None + needs_data = has_data + needs_data_transpose = has_data_transpose + + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP8 data is required to generate FP8 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self.data = None + if not needs_data_transpose: + self.data_t = None + + def _create_transpose(self): + """Create transposed quantized tensor""" + if not self.data.is_contiguous(): + self.data = self.data.contiguous() + self.data_t = self.data.t().contiguous() + self.scale_t = self.scale + + def size(self, *args, **kwargs): # pylint: disable=unused-argument + """Return the original tensor shape, not the internal packed data shape. + + FP4 quantization packs two 4-bit values into each 8-bit value, which reduces + the second dimension by half. This method returns the logical shape that + users expect, not the internal packed storage shape. + """ + assert self.original_shape is not None + return torch.Size(self.original_shape) + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0], + dtype=torch.float32, + ) + + +class NVFP4QuantizerRef(ExperimentalQuantizer): + """NVFP4 quantizer for middleware between Transformer Engine and Kitchen""" + + def __init__( + self, + dtype: utils.Fp4Formats, + rowwise: bool = True, + columnwise: bool = True, + pow_2_scales: bool = False, + eps: float = 0.0, + quant_tile_shape: Tuple[int, int] = (1, 16), + with_rht: bool = False, + with_random_sign_mask: bool = True, + ): + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = dtype + self.pow_2_scales = pow_2_scales + self.eps = eps + self.quant_tile_shape = quant_tile_shape + self.with_rht = with_rht + self.with_random_sign_mask = with_random_sign_mask + + @staticmethod + def _build_hadamard_matrix( + size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True + ) -> torch.Tensor: + """Construct a Hadamard matrix of given power-of-two size with entries +-1. + + Uses Sylvester construction to avoid SciPy dependency. + """ + assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + h = torch.ones((1, 1), device=device, dtype=torch.float32) + while h.shape[0] < size: + h = torch.cat( + [ + torch.cat([h, h], dim=1), + torch.cat([h, -h], dim=1), + ], + dim=0, + ) + if with_random_sign_mask: + sign_mat = get_wgrad_sign_vector().to(device) * torch.eye( + size, device=device, dtype=torch.float32 + ) + h = sign_mat @ h + return h.to(dtype) + + def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: + """Apply randomized Hadamard transform without random signs (reference path). + + This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))). + """ + # Only apply when enabled + if not self.with_rht: + return x + + # RHT dimension equals the quantization tile length (NVFP4 uses 16) + rht_dim = self.quant_tile_shape[1] + assert ( + x.shape[-1] % rht_dim == 0 + ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + + # Build H and scale + H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask) + scale = 1.0 / float(rht_dim) ** 0.5 + + # Perform blockwise transform along the last dimension + original_shape = x.shape + x_mat = x.contiguous().view(-1, rht_dim) + # Random sign matrix is identity in this reference (no sign flipping) + transform = H * scale + out = x_mat @ transform + return out.view(original_shape) + + @staticmethod + def _recover_swizzled_scales( + swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int + ) -> torch.Tensor: + if not swizzled_scale: + return scale + rounded_m = utils.roundup_div(m, 128) * 128 + scale_n = utils.roundup_div(n, block_length) + rounded_n = utils.roundup_div(scale_n, 4) * 4 + # Recover swizzled scaling factor layout -> linear layout + tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4)) + # after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4] + tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + result = torch.reshape(tmp, (rounded_m, rounded_n)) + return result[:m, :scale_n] + + @classmethod + def _quantize_blockwise_reference( + cls, + x: torch.Tensor, + global_amax: torch.Tensor, + tile_len_x: int, + tile_len_y: int, + *, + pow_2_scales: bool, + eps: float, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, torch.Tensor]: + + assert x.ndim == 2 + using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 + m, n = x.shape + # Compute vec_max based on the original x (before reshape) + # For 1D quantization: amax over each row chunk of 16 + # For 2D quantization: amax over each 16x16 block, but output shape is still (128, 8, 1), filled with block amax + if using_2d_quantization: + # x shape: (128, 128) + x_blocks = ( + x.unfold(0, tile_len_y, tile_len_y) + .unfold(1, tile_len_x, tile_len_x) + .to(torch.float32) + ) # (8, 8, 16, 16) + block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2)) # (8, 8) + # Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows + vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1) # (128, 8, 1) + else: + # x shape: (128, 128) + x_reshaped = x.view(m, n // tile_len_x, tile_len_x) # (128, 8, 16) + vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to( + torch.float32 + ) # (128, 8, 1) + x = x.view(m, n // tile_len_x, tile_len_x) + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) + + if pow_2_scales: + decode_scale = cast_to_e8(decode_scale) + encode_scale = torch.div( + torch.tensor(1.0, device=x.device, dtype=torch.float32), + decode_scale.to(torch.float32), + ) + else: + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=global_encode_scale.device, + dtype=torch.float32, + ), + ) + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + decode_scale = decode_scale * global_encode_scale + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + + scaled_x = x.to(torch.float32) * encode_scale + + clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) + + return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + + @staticmethod + def _pad_tensor( + tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] + ) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = tensor.shape + padding_needed_rows = 0 + padding_needed_cols = 0 + + if row_divisor is not None and M % row_divisor != 0: + padding_needed_rows = row_divisor - (M % row_divisor) + # Check and calculate column padding if col_divisor is provided + if col_divisor is not None and N % col_divisor != 0: + padding_needed_cols = col_divisor - (N % col_divisor) + + # Return original tensor if no padding is needed + if padding_needed_rows == 0 and padding_needed_cols == 0: + return tensor + + # pad the tensor + out = torch.nn.functional.pad( + tensor, + (0, padding_needed_cols, 0, padding_needed_rows), + mode="constant", + value=0.0, + ).contiguous() + + return out + + @staticmethod + def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = original_size + out = tensor[:M, :N].contiguous() + return out + + def _quantize(self, tensor: torch.Tensor) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ]: + """ + Python implementation of microblock FP4 quantization. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor to quantize (should be 2D) + + Returns + ------- + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor] + (qx, sx, qx_t, sx_t, global_amax) where: + - qx: quantized data in row-major order (if rowwise_usage), None otherwise + - sx: scale tensor for qx (if rowwise_usage), None otherwise + - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise + - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise + - global_amax: global amax tensor + """ + if self.pow_2_scales: + assert self.quant_tile_shape == ( + 1, + 32, + ), "MXFP4 only supports 1x32 tile shape." + # TODO(etsykunov): Fix bug where global_amax_row and + # global_amax_col are not defined + # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) + else: + assert self.quant_tile_shape in ( + (1, 16), + (16, 16), + ), "NVFP4 only supports 1x16 or 16x16 tile shape." + # Prepare inputs once so we can reuse for both amax and quantization + # Row-input will always be the original input. + row_input = tensor + col_input = ( + self._apply_rht(tensor.t().contiguous()) + if self.with_rht + else tensor.t().contiguous() + ) + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) + + transpose_scales = False + + M, N = tensor.shape + if self.rowwise_usage: + x_input = row_input + x_padded = self._pad_tensor( + x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx, sx = self._quantize_blockwise_reference( + x_padded, + global_amax_row, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + if transpose_scales: + sx = sx.T + + qx = self._rm_pad_tensor(qx, (M, N // 2)) + + else: + qx = None + sx = None + + if self.columnwise_usage: + x_t = col_input + x_t_padded = self._pad_tensor( + x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + + qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) + + if transpose_scales: + sx_t = sx_t.T + else: + qx_t = None + sx_t = None + + return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col + + def quantize( + self, + tensor: torch.Tensor, + **kwargs, # pylint: disable=unused-argument + ) -> NVFP4TensorRef: + # sanity checks + assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype." + + # Make it work with 3D tensors + original_shape = tensor.shape + if tensor.ndim > 2: + tensor = tensor.view(-1, tensor.shape[-1]) + + qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor) + + return NVFP4TensorRef( + data=qx, + scale=sx, + data_t=qx_t, + scale_t=sx_t, + global_amax_row=global_amax_row, + global_amax_col=global_amax_col, + dtype=tensor.dtype, + device=tensor.device, + quant_dtype=self.dtype, + quantizer=self, + original_shape=original_shape, + ) + + def update_quantized( + self, + src: torch.Tensor, + dst: ExperimentalQuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> ExperimentalQuantizedTensor: + """Update the quantized tensor with the given tensor in-place + + Parameters + ---------- + src: torch.Tensor + Source tensor to copy from + dst: ExperimentalQuantizedTensor + Destination ExperimentalQuantizedTensor to update + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + """ + # Handle noop flag + if noop_flag is not None and noop_flag.item() != 0: + return dst + + # Make sure input is in expected format + if not src.is_contiguous(): + src = src.contiguous() + + # Store the original shape and reshape for processing + original_shape = src.shape + if src.ndim > 2: + src = src.view(-1, src.shape[-1]) + + qx, sx, qx_t, sx_t, global_amax = self._quantize(src) + + # Update the destination with new data + dst.data = qx + dst.scale = sx + dst.data_t = qx_t + dst.scale_t = sx_t + dst.global_amax = global_amax + dst.dtype = src.dtype + dst.quant_dtype = self.dtype + dst.original_shape = original_shape + + return dst + + @property + def supports_allgather_fp8(self) -> bool: + """Whether the tensor data can be all-gathered with an FP8 all-gather. + + TODO(etsykunov): Confirm docstring is correct. Also, this API + seems too FP8-specific and should be reconsidered. + """ + return False + + def transpose_qresult( + self, qresult: quantization.ExperimentalQuantizedTensor + ) -> quantization.ExperimentalQuantizedTensor: + """Convert row-wise data to column-wise data (?) + + TODO(etsykunov): Confirm docstring is correct. + """ + raise NotImplementedError("Transpose qresult is not implemented for FP4.") + + @property + def supports_dequantize(self) -> bool: + """Whether quantized tensor can converted to high-precision tensor""" + return False + + @property + def is_data_t_transposed_in_memory(self) -> bool: + """Whether column-wise data is stored in transposed layout. + + TODO(etsykunov): Confirm docstring is correct. + """ + raise NotImplementedError("Not implemented yet") + + def dequantize( + self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None + ) -> torch.Tensor: + """Dequantize the quantized tensor""" + raise NotImplementedError("Not implemented yet") + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + m_params: quantization.MMParams, + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP, + qresult_x: quantization.ExperimentalQuantizedTensor | None = None, + qresult_w: quantization.ExperimentalQuantizedTensor | None = None, + ) -> torch.Tensor: + assert bias is None, "Bias is implemented for FP4 GEMM." + + high_precision_x = cast_from_fp4x2(qx, out_dtype) + high_precision_w = cast_from_fp4x2(qw, out_dtype) + + if self.pow_2_scales: + + if sx.dtype == torch.uint8: + # if scaling factor is stored in uint8 container + sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** ( + ( + sx.to(torch.float32) + - torch.tensor(127, device=sx.device, dtype=torch.float32) + ) + ) + sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** ( + ( + sw.to(torch.float32) + - torch.tensor(127, device=sw.device, dtype=torch.float32) + ) + ) + else: + # if scaling factor is torch.float8_e8m0fnu + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32) + + else: + + assert qresult_x is not None + assert qresult_w is not None + + assert qresult_x.global_amax_row is not None + assert qresult_w.global_amax_col is not None + + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + factor = 6.0 * 6.0 * 448.0 * 448.0 + + if gemm_type == quantization.GEMMType.WGRAD: + partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col + else: + partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row + alpha = torch.div(partial_alpha, factor).squeeze(-1) + + M, K = high_precision_x.shape + N, K_w = high_precision_w.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + assert K % 32 == 0, "K dimension must be divisible by 32" + assert N % 8 == 0, "N dimension must be divisible by 8" + + block_length = 32 if self.pow_2_scales else 16 + + grid_k = K // block_length + + assert sx.shape == ( + M, + K // block_length, + ), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}" + assert sw.shape == ( + N, + K // block_length, + ), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}" + + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # below implementation is to match the FP4 tensor core implementation + # Each output element (i, j) is fp32 accumulation of (K // block_length) inner products + # Each inner product is sx * sw * (1, block_length) x (block_length, 1) with precision in fp32 + # Then batch the computation in M, N dimension + for k in range(grid_k): + k_start = k * block_length + k_end = k_start + block_length + + qx_block = high_precision_x[:, k_start:k_end].clone().contiguous() + qw_block = high_precision_w[:, k_start:k_end].clone().contiguous() + + # Extract scaling factors for the current blocks + sx_block = sx[:, k] + sw_block = sw[:, k] + + y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref( + qx_block, qw_block, torch.float32, is_b_transposed=True + ) + + if not self.pow_2_scales and K > 0: + # only apply global scale for NVFP4 and non-empty cases + y = alpha * y + + # accumulation happens at epilogue in float32 + if accumulate: + assert out is not None, "Output tensor must be provided for accumulation." + y += out.to(torch.float32) + else: + assert out is None, "Output tensor should be None when accumulate is False." + + y = y.to(out_dtype) + return y diff --git a/transformer_engine/pytorch/experimental/utils.py b/transformer_engine/pytorch/experimental/utils.py new file mode 100644 index 000000000..20dc6f11b --- /dev/null +++ b/transformer_engine/pytorch/experimental/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Utility functions for experimental middleware between Transformer Engine and Kitchen.""" + +import enum + +import torch + + +HIGH_PRECISION_FLOAT_DTYPES = ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, +) + + +class Fp4Formats(enum.Enum): + """FP4 data format""" + + E2M1 = "e2m1" + + +def roundup_div(x: int, y: int) -> int: + """Round up division""" + assert x >= 0 + assert y > 0 + return (x + y - 1) // y diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 8f9dbd88d..a75a03bfa 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -21,6 +21,7 @@ MXFP8BlockScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, ) from .constants import dist_group_type @@ -53,6 +54,13 @@ def check_mxfp8_support() -> Tuple[bool, str]: return False, "Device compute capability 10.0 or higher required for MXFP8 execution." +def check_nvfp4_support() -> Tuple[bool, str]: + """Return if nvfp4 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." + + def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" if ( @@ -105,6 +113,13 @@ def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType return tex.DType.kFloat8E5M2 +def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: + """Get fp4 data type according to recipe and tensor""" + if fp4_recipe.fp4_format == Format.E2M1: + return tex.DType.kFloat4E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: """Get max representible FP8 value.""" if fp8_recipe.fp8_format == Format.E4M3 or ( @@ -142,6 +157,8 @@ class FP8GlobalStateManager: reason_for_no_mxfp8 = "" fp8_block_scaling_available = None reason_for_no_fp8_block_scaling = None + nvfp4_available = None + reason_for_no_nvfp4 = "" @classmethod def reset(cls) -> None: @@ -205,6 +222,13 @@ def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: ) return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling + @classmethod + def is_nvfp4_available(cls) -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + if cls.nvfp4_available is None: + cls.nvfp4_available, cls.reason_for_no_nvfp4 = check_nvfp4_support() + return cls.nvfp4_available, cls.reason_for_no_nvfp4 + @staticmethod def get_meta_tensor_key(forward: bool = True) -> str: """Returns scaling key in `fp8_meta`.""" @@ -481,6 +505,9 @@ def fp8_autocast_enter( if isinstance(fp8_recipe, Float8BlockScaling): fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() assert fp8_block_available, reason_for_no_fp8_block + if isinstance(fp8_recipe, NVFP4BlockScaling): + nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() + assert nvfp4_available, reason_for_no_nvfp4 @classmethod def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: @@ -837,6 +864,8 @@ def create( cls = Float8CurrentScalingRecipeState elif recipe.float8_block_scaling(): cls = Float8BlockScalingRecipeState + elif recipe.nvfp4(): + cls = NVFP4BlockScalingRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( @@ -1084,3 +1113,79 @@ def make_quantizers(self) -> list: ] ) ) + + +class NVFP4BlockScalingRecipeState(RecipeState): + """Configuration for NVFP4 quantization. + + NVFP4 quantization does not require state. + + """ + + recipe: NVFP4BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + from .tensor.nvfp4_tensor import NVFP4Quantizer + + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward. It assumes forward quantizers are + # ordered [input, weight, output, ...] and backward quantizers + # are ordered [grad_output, grad_input, ...]. This doesn't + # play nicely with fusible ops: Linear op doesn't own output + # or grad input quantizers, Quantize op only owns input and + # grad output quantizers. + + if self.mode == "forward": + + def _make_quantizer(idx: int) -> NVFP4Quantizer: + qparams = ( + self.recipe.fp4_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp4_quant_fwd_inp + ) + return NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, + stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + ) + for _ in range(self.num_quantizers) + ] + + raise RuntimeError(f"Unexpected recipe mode ({self.mode})") diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index e4fa0c741..3505a6830 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -4,16 +4,18 @@ """Internal function used by multiple modules.""" -from typing import Any, List, Optional, Tuple, Union, Callable -from dataclasses import dataclass - +import dataclasses import queue +from typing import Any, Callable, List, Optional, Tuple, Union + import torch from .. import cpp_extensions as tex +from .. import experimental from ..constants import TE_DType -from ..utils import get_default_init_method from ..export import is_in_onnx_export_mode +from ..tensor.utils import is_experimental +from ..utils import get_default_init_method def _get_normalization_func(normalization: str, forward: bool): @@ -170,7 +172,33 @@ def noop_cat( return _NoopCatFunc.apply(dim, *tensors) -@dataclass +def get_module_quantizers( + module: torch.nn.Module, + fp8_output: bool, + fp8_grad: bool, + debug: bool, +): + """Return the 6-tuple of quantizers for a module in a centralized way. + + Routing policy: + - If experimental quantization is enabled via environment and module.fp8 is True, + return experimental quantizers. + - Otherwise, return the module's own quantizers (debug or regular). + """ + if getattr(module, "fp8", False) and is_experimental(): + # TODO(etsykunov): Quantizer instantiation should be better + # done in the module's constructor + qlinear_params = experimental.config.set_qlinear_params() + + if qlinear_params is not None: + return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params) + + if not debug: + return module._get_quantizers(fp8_output, fp8_grad) + return module._get_debug_quantizers(fp8_output, fp8_grad) + + +@dataclasses.dataclass class _ParameterInitMeta: """ Stores essential metadata needed to support deferred parameter initialization. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 70366dabe..bf4fb97d2 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -27,6 +27,7 @@ DelayedScalingRecipeState, Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -39,6 +40,7 @@ from ..constants import dist_group_type from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase @@ -76,7 +78,8 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - return 33_554_432 + # 32 MiB for NVFP4 GEMM, plus 256 B for misc scales + return 32 * 1024 * 1024 + 256 return 4_194_304 @@ -757,6 +760,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8BlockScalingRecipeState ): return + if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -1218,15 +1223,13 @@ def grad_output_preprocess( ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - if isinstance(quantizer, Float8BlockQuantizer): + # TODO(ksivaman): Re-add fusion once kernel is available. + if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance( - grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), - ): + if not isinstance(grad_output, QuantizedTensorBase): grad_output = quantizer(grad_output) return grad_output, grad_bias diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 4d30be414..6dbbd335e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -29,6 +30,7 @@ from ..fp8 import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, + assert_dim_for_all_gather, cast_if_needed, clear_tensor_data, divide, @@ -53,7 +55,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, WeightGradStore +from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers from ..tensor.quantized_tensor import ( QuantizedTensor, QuantizedTensorBase, @@ -135,6 +137,8 @@ def forward( if ub_name is not None: nvtx_label = f"{nvtx_label}.{ub_name}" + with_input_all_gather = parallel_mode == "column" and sequence_parallel + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -144,6 +148,7 @@ def forward( inputmat = inp if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") @@ -157,7 +162,6 @@ def forward( weight_requires_grad = weight.requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad - with_input_all_gather = parallel_mode == "column" and sequence_parallel # Configure Userbuffers communication (comm+GEMM overlap) if debug: # turn off userbuffers in debug mode @@ -190,11 +194,13 @@ def forward( # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. + experimental = is_experimental(input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not experimental ) # Apply normalization @@ -240,7 +246,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = input_quantizer - if not with_quantized_norm: + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: ln_out = quantizer(ln_out) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather @@ -1422,6 +1429,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif other recipes (mxfp8, etc) def reset_layer_norm_parameters(self) -> None: @@ -1526,11 +1535,7 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad) - ) + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1763,6 +1768,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 9f799c553..a0e5f3aed 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -17,6 +17,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_experimental from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -40,6 +41,7 @@ init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -64,6 +66,7 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload @@ -114,7 +117,8 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling: [] - if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + # TODO(ksivaman): Fuse nvfp4 act once kernel is available. + if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4(): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), @@ -211,6 +215,7 @@ def forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -258,11 +263,13 @@ def forward( # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm + experimental = is_experimental(fc1_input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not experimental ) # Apply normalization @@ -302,7 +309,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer - if not with_quantized_norm: + # experimental recipe doesn't need to support quantized AG + if not with_quantized_norm and not experimental: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -548,6 +556,7 @@ def forward( if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, @@ -673,6 +682,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None @@ -1014,7 +1024,10 @@ def fc2_wgrad_gemm( if ctx.fp8: # TODO float8 blockwise current scaling has no bgrad fusion for now - if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): + # TODO(ksivaman): Re-add fusion once kernel is available. + if isinstance( + ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) else: @@ -1690,6 +1703,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_layer_norm_parameters(self) -> None: @@ -1908,7 +1923,10 @@ def _get_quantizers(self, fp8_output): fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, - columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), + columnwise=isinstance( + fc2_input_quantizer, + (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer), + ), ) fc1_input_quantizer.internal = True if fp8_output: @@ -2113,6 +2131,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.set_parallel_mode: + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].amax_reduction_group = self.tp_group + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: """Get the weight tensors of the module.""" return [self.fc1_weight, self.fc2_weight] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 7e526245c..cf7f58947 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,7 +25,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, WeightGradStore +from ._common import noop_cat, WeightGradStore, get_module_quantizers from ..fp8 import FP8GlobalStateManager from ..utils import ( cast_if_needed, @@ -35,6 +35,7 @@ requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, nvtx_range_pop, nvtx_range_push, ) @@ -65,6 +66,7 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.utils import is_experimental from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState @@ -151,6 +153,9 @@ def forward( ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG + # experimental recipe check + experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer) + # ------------------------------------------------------ # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -161,6 +166,7 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer @@ -172,7 +178,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase): + if not isinstance(inputmat, QuantizedTensorBase) and not experimental: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -442,6 +448,7 @@ def forward( ctx.main_grad_func = lambda: weight.main_grad ctx.debug = debug + ctx.experimental = experimental ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -609,7 +616,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat, QuantizedTensorBase): # Input tensor is already quantized pass - elif ctx.debug: + elif ctx.debug or ctx.experimental: # Debug quantizer will be applied immediately before wgrad GEMM pass else: @@ -698,6 +705,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # dgrad GEMM # Note: dx = dy * w + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weight_fp8, @@ -1326,6 +1334,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) def reset_parameters(self, defer_init=False): @@ -1410,12 +1420,7 @@ def forward( weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = ( - self._get_quantizers(fp8_output, fp8_grad) - if not debug - else self._get_debug_quantizers(fp8_output, fp8_grad) - ) - + quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1655,6 +1660,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 70c70c54d..f8f95cf19 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -926,6 +926,7 @@ def op_forward( input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) weight_quantizer.set_usage(rowwise=True, columnwise=False) + # Recipe-specific configuration recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_current_scaling(): input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale @@ -940,6 +941,13 @@ def op_forward( if self.sequence_parallel and self.tensor_parallel_mode == "row": grad_output_quantizer.with_amax_reduction = True grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + if recipe.nvfp4(): + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Get autocast dtype if needed if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 7fa12cc08..43846512d 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -54,6 +54,7 @@ def get_all_tensor_types(): Float8BlockwiseQTensor, Float8BlockwiseQTensorBase, ) + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase all_tensor_types = [ torch.Tensor, @@ -64,5 +65,7 @@ def get_all_tensor_types(): MXFP8TensorBase, Float8BlockwiseQTensor, Float8BlockwiseQTensorBase, + NVFP4Tensor, + NVFP4TensorBase, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py new file mode 100644 index 000000000..df187d674 --- /dev/null +++ b/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py @@ -0,0 +1,348 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Mixin class holding data specific for NVFP4Tensor""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import math +from typing import Any, Dict, Optional, Tuple, Union +import warnings + +import torch + +# import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ..quantized_tensor import QuantizedTensorBase + +# from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ..quantized_tensor import Quantizer +from ...utils import _empty_tensor + + +@functools.lru_cache(maxsize=None) +def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Values representable in FP4 E2M1 format""" + return torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + device=device, + dtype=dtype, + ) + + +class _FromNVFP4Func(torch.autograd.Function): + """Cast from NVFP4 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: NVFP4TensorBase, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + + # Dequantize row-wise data + if tensor._rowwise_data is not None: + ### TODO(tmoon): Debug dequantize kernel and remove unfused impl + # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) + + # Tensor properties + shape = list(tensor._rowwise_data.size()) + shape[-1] *= 2 + device = tensor._rowwise_data.device + + # Convert FP4E2M1 values to FP32 + data = tensor._rowwise_data.view(torch.uint8).to(torch.int32) + data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape) + data = _fp4_e2m1_vals(device, dtype=torch.float32)[data] + data = data.to(torch.float32).contiguous() + + # Convert FP8E4M3 block scales to FP32 + block_scales = tensor._rowwise_scale_inv + block_scales = block_scales.reshape(-1, block_scales.size(-1)) + block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16] + block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32) + + # Convert amax to FP32 tensor scale + tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max + + # Apply scales + block_data = data.view(-1, 16) + block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1) + + return data.to(dtype) + + if tensor._columnwise_data is not None: + raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!") + raise ValueError("Attempted to dequantize NVFP4 tensor with no data") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class NVFP4TensorBase(QuantizedTensorBase): + """Mixin class that holds data attributes of NVFP4Tensor. + + NVFP4Tensor inherits from the PyTorch tensor class and this mixin + class. If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + _fp4_dtype: TE_DType + _amax_rowwise: torch.Tensor + _amax_columnwise: torch.Tensor + + def __new__( + cls, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + quantizer: Optional[Quantizer], + *args, + **kwargs, + ): + + instance = super().__new__(cls, *args, **kwargs) + + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._fp4_dtype = fp4_dtype + instance._quantizer = quantizer.copy() if quantizer is not None else None + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._amax_rowwise = amax_rowwise + instance._amax_columnwise = amax_columnwise + + return instance + + def clear(self): + """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" + for t in ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ): + if t is not None: + t.data = _empty_tensor() + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "amax_rowwise": self._amax_rowwise, + "amax_columnwise": self._amax_columnwise, + "fp4_dtype": self._fp4_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]: + """Prepare the tensor base for saving for backward""" + tensors = [ + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ] + self._rowwise_data = None + self._columnwise_data = None + self._rowwise_scale_inv = None + self._columnwise_scale_inv = None + self._amax_rowwise = None + self._amax_columnwise = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + self._rowwise_scale_inv = tensors[2] + self._columnwise_scale_inv = tensors[3] + self._amax_rowwise = tensors[4] + self._amax_columnwise = tensors[5] + return tensors[6:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromNVFP4Func.forward(None, self, dtype) + + def size(self, dim: Optional[int] = None) -> Union[torch.Size, int]: + # pylint: disable=missing-function-docstring + + # Infer tensor shape + shape = None + if self._rowwise_data is not None: + byte_shape = list(self._rowwise_data.size()) + shape = byte_shape[:-1] + [byte_shape[-1] * 2] + elif self._columnwise_data is not None: + warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.") + byte_shape = list(self._columnwise_data.size()) + shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]] + if shape is None: + raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data") + + # Return shape or dim + if dim is None: + return torch.Size(shape) + return shape[dim] + + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + + # Return input tensor if view not needed + cur_shape = self.size() + if shape is None or shape == cur_shape: + return self + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if self._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = self._rowwise_data.view(byte_shape) + if self._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = self._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4TensorBase( + rowwise_data=new_rowwise_data, + rowwise_scale_inv=self._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=self._columnwise_scale_inv, + amax_rowwise=self._amax_rowwise, + amax_columnwise=self._amax_columnwise, + quantizer=self._quantizer, + fp4_dtype=self._fp4_dtype, + ) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorBase(" + f"rowwise_scaled_data={data_rowwise}," + f"rowwise_scale_inv={self._rowwise_scale_inv}," + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + For the NVFP4 format, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. + """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses" + ) + if self._amax_rowwise is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing per tensor" + " row-scaled scale-inverse" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + self._amax_rowwise = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing column-scaled scale-inverses" + ) + if self._amax_columnwise is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing per tensor column-scaled scale-inverse" + ) + else: + self._columnwise_data = None + self._columnwise_scale_inv = None + self._amax_columnwise = None diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 321c351dd..d7f5f8c7d 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Tensor class with FP8 data""" +"""Tensor class with MXFP8 data""" from __future__ import annotations from collections.abc import Iterable import math @@ -186,8 +186,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. + precision. dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py new file mode 100644 index 000000000..b12e89956 --- /dev/null +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -0,0 +1,898 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with NVFP4 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Optional, Tuple, Union +import functools + +import torch +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe +from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..utils import ( + canonicalize_process_group, + devices_match, + round_up_to_nearest_multiple, +) + +from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func +from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc + +aten = torch.ops.aten + + +def get_no_random_sign_vector() -> torch.Tensor: + """Non-random sign vector for Hadamard transform.""" + return torch.tensor([1], dtype=torch.float32) + + +def get_sign_from_vector(vector: torch.Tensor) -> int: + """Convert sign vector to bitmask. + + Used for random Hadamard transform. + + """ + mask = 0 + for i, v in enumerate(vector): + mask |= (v == -1) << i + return mask + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded random signs for Hadamard transform. + + https://xkcd.com/221/ + + """ + return torch.tensor( + [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], + dtype=torch.float32, + ) + + +def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor: + """Construct a 16x16 Hadamard matrix.""" + assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported." + hadamard_scale = 1 / math.sqrt(hadamard_dimension) + return ( + torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1], + [1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1], + [1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1], + [1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1], + [1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1], + [1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1], + [1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1], + [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1], + [1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1], + [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1], + [1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1], + [1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1], + [1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1], + [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1], + ], + dtype=torch.float32, + ) + * hadamard_scale + ) + + +@functools.lru_cache(maxsize=None) +def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor: + """Construct matrix used in random Hadamard transform.""" + hadamard_dimension = 16 + if with_random_sign_mask: + signs = get_wgrad_sign_vector() + else: + signs = get_no_random_sign_vector() + sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32) + rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension) + return rht_matrix.to(dtype=torch.bfloat16).cuda() + + +@functools.lru_cache(maxsize=None) +def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int: + """Sign mask for random Hadamard transform.""" + if with_random_sign_mask: + return get_sign_from_vector(get_wgrad_sign_vector()) + return 0 + + +class NVFP4Quantizer(Quantizer): + """Builder class for NVFP4 tensors with NV block scaling""" + + dtype: TE_DType + """Random Hadamard Transform""" + with_rht: bool + with_post_rht_amax: bool + """amax reduction options""" + with_amax_reduction: bool + amax_reduction_group: Optional[dist_group_type] + + """2D block scaling, only applicable for weights.""" + with_2d_quantization: bool + + """Stochastic rounding, only applicable for gradients.""" + stochastic_rounding: bool + + """RHT matrix random sign mask""" + rht_matrix_random_sign_mask_t: int + rht_matrix: torch.Tensor + + def __init__( + self, + fp4_dtype: TE_DType = tex.DType.kFloat4E2M1, + rowwise: bool = True, + columnwise: bool = True, + with_amax_reduction: bool = False, + amax_reduction_group: Optional[dist_group_type] = None, + with_rht: bool = False, + with_post_rht_amax: bool = False, + with_2d_quantization: bool = False, + stochastic_rounding: bool = False, + with_random_sign_mask: bool = True, + ) -> None: + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.dtype = fp4_dtype + self.with_rht = with_rht + self.with_post_rht_amax = with_post_rht_amax + self.with_amax_reduction = with_amax_reduction + self.amax_reduction_group = amax_reduction_group + self.with_2d_quantization = with_2d_quantization + self.stochastic_rounding = stochastic_rounding + self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask) + self.rht_matrix = get_rht_matrix(with_random_sign_mask) + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + + assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type." + + # Make sure input is in expected format + if not devices_match(src.device, dst.device): + src = src.to(device=dst.device) + if not src.is_contiguous(): + src = src.contiguous() + + # Launch cast kernel + tex.quantize(src, self, dst, noop_flag) + + return dst + + def is_quantizable(self, inp: torch.Tensor) -> bool: + """Returns whether or not given inp can be quantized""" + if inp.ndim < 2: + return False + if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0: + return False + return True + + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. + + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For NVFP4 1D blockwise quantization, blocksize is 16 + - If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4)) + - If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4)) + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + M, K = 1, 1 + M = math.prod(shape[:-1]) + K = shape[-1] + + if columnwise: + outer = round_up_to_nearest_multiple(K, 128) + inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + # rowwise + outer = round_up_to_nearest_multiple(M, 128) + inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4) + return (outer, inner) + + @staticmethod + def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of a tensor after columnwise quantization. + + For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling. + + Parameters + ---------- + shape : Iterable[int] + Original shape of the tensor + + Returns + ------- + Tuple[int, ...] + New shape with dimensions rearranged for columnwise layout. + For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1). + Returns empty tuple for empty input shape. + """ + if len(shape) == 0: + return tuple() + # and then after AG, a reorganize kernel will be called to restore the shape + colwise_shape = [shape[-1]] + for i in range(len(shape) - 1): + colwise_shape.append(shape[i]) + return tuple(colwise_shape) + + @staticmethod + def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]: + """Convert shape for FP4 data by dividing the last dimension by 2""" + shape = list(shape) + shape[-1] = shape[-1] // 2 + return tuple(shape) + + def make_empty( + self, + shape: Iterable[int], + *, + dtype: torch.dtype = torch.float32, + device: Optional[torch.device] = None, + requires_grad: bool = False, + ) -> NVFP4Tensor: + + # Canonicalize tensor attributes + if device is None: + device = torch.device("cuda") + + assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + flat_first_dim = math.prod(shape[:-1]) + assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, ( + f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" + f" {NVFP4_BLOCK_SCALING_SIZE}" + ) + + # Allocate FP4 data + data = None + scale_inv = None + amax_rowwise = None + if self.rowwise_usage: + data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device) + scale_shape = self.get_scale_shape(shape, columnwise=False) + scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device) + # Allocate per tensor scale inverse. FP32 format. + amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device) + + # Allocate FP8 data transpose if needed + columnwise_data = None + columnwise_scale_inv = None + amax_columnwise = None + if self.columnwise_usage: + # enforce 2D shape to avoid [S, B, H] shape and B and be 1 + # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + shape_2d = tuple([flat_first_dim, shape[-1]]) + columnwise_data = torch.empty( + self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), + dtype=torch.uint8, + device=device, + ) + columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) + columnwise_scale_inv = torch.empty( + columnwise_scale_shape, dtype=torch.uint8, device=device + ) + amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device) + + # Construct FP8 tensor + return NVFP4Tensor( + shape=shape, + dtype=dtype, + rowwise_data=data, + rowwise_scale_inv=scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + fp4_dtype=self.dtype, + quantizer=self, + requires_grad=requires_grad, + ) + + def calibrate(self, tensor: torch.Tensor) -> None: + pass # Calibration is no-op + + def _canonicalized_amax_reduction_group(self) -> dist_group_type: + """Get process group for amax reduction""" + return canonicalize_process_group(self.amax_reduction_group) + + def _get_compatible_recipe(self) -> Union[type[Recipe], None]: + return NVFP4BlockScaling + + +class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): + """Quantized tensor class with FP4 data + + The tensor presents as having a standard, higher-precision dtype, + but the data itself is (scaled) FP4. For most tensor operations, + the data will be cast to the nominal dtype before performing the + operation. + + Parameters + ---------- + rowwise_data: torch.Tensor + Raw FP4 data in a uint8 tensor (rowwise layout). + rowwise_scale_inv: torch.Tensor + Reciprocal of the scaling factor applied when + casting to FP4, i.e. the scaling factor that must + be applied when casting from FP4 to higher + precision (rowwise). + columnwise_data: torch.Tensor, optional + Raw FP4 data in a uint8 tensor (columnwise layout). + columnwise_scale_inv: torch.Tensor, optional + Reciprocal of the scaling factor for columnwise FP4 data. + amax_rowwise: torch.Tensor, optional + Rowwise amax tracking tensor. + amax_columnwise: torch.Tensor, optional + Columnwise amax tracking tensor. + fp4_dtype: TE_DType + The FP4 data type used for quantization. + quantizer: Quantizer + The quantizer instance used for this tensor. + dtype: torch.dtype, default = torch.float32 + Nominal tensor datatype, used in dequantize. + """ + + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args, + # which significantly reduces the Pybind11 overhead when calling the constructor from C++. + def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + amax_rowwise: Optional[torch.Tensor], + amax_columnwise: Optional[torch.Tensor], + fp4_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + fp4_dtype, + quantizer, + *args, + **kwargs, + ) + return instance + + def __repr__(self, *, tensor_contents=None): + return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from NVFP4Tensor + + By default the resulting tensor's dtype is the + NVFP4Tensor's nominal dtype. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromNVFP4Func.apply(self, dtype) + return _FromNVFP4Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return NVFP4Quantizer() + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> NVFP4Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return NVFP4Tensor.make_like(self) + + def clone(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> NVFP4Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. + + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("NVFP4Tensor does not support different memory formats!") + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + if len(args) != 2: + raise RuntimeError("Unexpected args for view op (expected 2 args, got {len(args)})") + tensor = args[0] + shape = args[1] + if shape == list(tensor.size()): + return tensor.detach() + return tensor.view(shape) + + # NVFP4 dequantize not supported. Add manual support for needed funcs. + if func in (aten.empty_like.default, aten.zero_.default): + tensor = args[0] + data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like + scale_inv_init_func = ( + torch.ones_like if func == aten.zero_.default else torch.empty_like + ) + + if tensor._rowwise_data is not None: + rowwise_data = data_init_func(tensor._rowwise_data) + rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv) + amax_rowwise = torch.zeros_like(tensor._amax_rowwise) + else: + rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None + + if tensor._columnwise_data is not None: + columnwise_data = data_init_func(tensor._columnwise_data) + columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv) + amax_columnwise = torch.zeros_like(tensor._amax_columnwise) + else: + columnwise_data, columnwise_scale_inv, amax_columnwise = ( + None, + None, + None, + ) + + return NVFP4Tensor( + shape=tensor.shape, + dtype=tensor.dtype, + fp4_dtype=tensor._fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=tensor._quantizer, + requires_grad=tensor.requires_grad, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + ) -> NVFP4Tensor: + """Build NVFP4Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + NVFP4Tensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + self._fp4_dtype, + self.dtype, + self._quantizer, + ), + ) + + def _get_data(self) -> NVFP4Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a NVFP4Tensor. Otherwise + casts to FP8. + + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) + + # Just copy FP8 data if other tensor is NVFP4Tensor + if isinstance(tensor, NVFP4Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + NVFP4Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + self._amax_rowwise = tensor._amax_rowwise + self._amax_columnwise = tensor._amax_columnwise + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.update_quantized(tensor, self) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting NVFP4Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the NVFP4Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.view(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.view(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.view(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None + + +class _ReshapeFunc(torch.autograd.Function): + """Reshape function + + Reshape the NVFP4Tensor using the provided shape. + + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.reshape(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.reshape(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.reshape(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.reshape(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 656eda46c..7b88d2519 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -264,6 +264,10 @@ def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False + def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument + """Returns whether or not given tensor can be quantized""" + return True + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 23f56da5d..a4bdf5e07 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -4,11 +4,13 @@ """Helper functions for using fp8 tensors as weights""" +import os +from typing import Optional, Union import torch import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor +from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer @@ -450,3 +452,20 @@ def _cast_master_weights_to_fp8_blockwise_scaling( tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) + + +def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool: + """Check if an environment or object is using experimental Kitchen middleware. + + Returns False if x is a torch.Tensor. + """ + # Detect if the environment is experimental + if x is None: + return int(os.getenv("QAT_PARAMS", "0")) > 0 + + # Detect if the object is experimental + if isinstance(x, torch.Tensor): + return False + if not isinstance(x, (Quantizer, QuantizedTensorBase)): + raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance") + return hasattr(x, "experimental") and x.experimental diff --git a/transformer_engine/pytorch/triton/pad.py b/transformer_engine/pytorch/triton/pad.py new file mode 100644 index 000000000..29b0daf31 --- /dev/null +++ b/transformer_engine/pytorch/triton/pad.py @@ -0,0 +1,94 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVFP4 padding kernels + +TODO(ksivamani): Documentation + +""" + +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1), + ], + key=["out_dim0", "out_dim1"], +) +@triton.jit +def zero_pad_kernel( + inp_ptr, + out_ptr, + in_dim0: tl.constexpr, + in_dim1: tl.constexpr, + out_dim0: tl.constexpr, + out_dim1: tl.constexpr, + in_s0, + in_s1, + out_s0, + out_s1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + # tile over OUTPUT coordinates + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols + om = offs_m[:, None] + on = offs_n[None, :] + + # edge masking for output + out_mask = (om < out_dim0) & (on < out_dim1) + + # valid input region is simply top-left (no offsets) + in_mask = (om < in_dim0) & (on < in_dim1) + + # load valid input, else zero (masked load touches memory only where True) + x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0) + + # store to output (only within bounds of the output tile) + tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask) + + +def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor: + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + assert inp.ndim == 2 + dim0, dim1 = inp.shape + + pad_x = (128 - dim0 % 128) % 128 + pad_y = (4 - dim1 % 4) % 4 + out_x = dim0 + pad_x + out_y = dim1 + pad_y + out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype) + + in_s0, in_s1 = inp.stride() + out_s0, out_s1 = out.stride() + + BLOCK_M, BLOCK_N = 128, 128 + grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N)) + + zero_pad_kernel[grid]( + inp, + out, + dim0, + dim1, + out_x, + out_y, + in_s0, + in_s1, + out_s0, + out_s1, + ) + return out diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 6420f3e12..1a0722f89 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -11,8 +11,8 @@ import numpy as np import torch -import transformer_engine.pytorch.cpp_extensions as ext from . import torch_version +from .tensor.quantized_tensor import Quantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor @@ -441,6 +441,16 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None: ) +def assert_dim_for_all_gather( + tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer +) -> None: + """Assert that tensor dimensions are supported for all-gather""" + if with_all_gather: + assert quantizer.is_quantizable(tensor), ( + "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ + ) + + def is_bf16_compatible() -> None: """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability to enforce sm_80 or higher. @@ -460,6 +470,8 @@ def is_non_tn_fp8_gemm_supported() -> bool: @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" + import transformer_engine.pytorch.cpp_extensions as ext + encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude) From 2354fb8b02ec64e73f82f2fd564f541c29a5e737 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Tue, 30 Sep 2025 09:01:24 -0700 Subject: [PATCH 014/141] Fix the segfault in the nvfp4 quantization (#2214) * Fix the segfault in the nvfp4 quantization Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemek Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/util/nvfp4_transpose.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh index fe9736298..712b557c5 100644 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -1433,7 +1433,8 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o const size_t block_size = THREADS_NUM; const size_t scale_stride = output->scale_inv.shape[1]; - const size_t scale_stride_transpose = output->columnwise_scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); nvfp4_scale_t *const scales_transpose_ptr = From 25252e9f2bc1460a841f32ef172126fb9192515a Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Tue, 30 Sep 2025 15:17:37 -0700 Subject: [PATCH 015/141] [PyTorch] Add FP8 attention with current scaling (#2012) * debug existing usage Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8_dpa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reimplement fp8_dpa Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * redesign CS; need cleanup Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up s/dP quantizers Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return dP to DS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * improve quantizer_helper; tweak dP DS/CS logic Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * debug CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up non-CP; debug dq/dk mismatches Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor success with CP; need to remove debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove debug info Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable fp8 output for fp8_mha + CS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add output_tensor_type to FADescriptor Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove print Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more fixes for non-CP and CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable non-determinism for blackwell Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix indent; remove print Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * switch from create_tensor_from_data to make_like Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable a2a+p2p for CS CP and require additional cp_group_global Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * condense tests; only create dist groups once Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * consolidate CP P2P per-tile calls for fwd/bwd and fused/flash Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix flash-attn from last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes for previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix attn_mask_type in f16 causal Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert bb6a0a59 temporarily Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reenable comparison for some tensors in CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dbias for fused attn CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clean up prints/comments and add back NVTE_CS_dP_SCALE Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * first attempt at mixed DS/CS reduction Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix for last commit for mixed DS/CS reduction Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints from 69639024 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix DS recipe for dP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_DPA_FORCE_DS to force DS for all DPA tensors, not just dP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix NVTE_DPA_FORCE_DS and add NVTE_PRINT Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * modify DS recipe for MLPerf Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce only over TP group; need to think about CP group later Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * streamline fake_recipe/quantizer generation; allow NVTE_DPA_Fixed_Scales or DS-update S/dP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more print: NVTE_LAYER_NUMBER Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * split S/dP in env vars: NVTE_DPA_Fix_S_Scale and NVTE_DPA_Fix_dP_Scale Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix autocast_key for DS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_REPEAT_in_F16 to repeat FP8 fwd/bwd passes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add FP8 CS to UnfusedDPA; unsuccessful; does not affect other backends Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * temporary: print min/max and save tensors for debugging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * emulate q/dq+bf16 with NVTE_Emulate_in_F16; add NVTE_DPA_FORCE_MXFP8 for MXFP8 q/dq Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add RHT to BMM1 with NVTE_RHT_BMM1 for the size Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * re-enable fused attn in dpa_fp8_vs_f16 test; changed during unfused attn implementation Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add NVTE_FP8_CS_POWER_OF_2, NVTE_DPA_FORCE_BLOCKFP8, NVTE_Emulate_QDQ_QKV, NVTE_Emulate_QDQ_O, NVTE_Emulate_QDQ_dO Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add F16 O support for FP8 kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to TE FE commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * return to FE develop Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * tidy up; untested Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes and improvements for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more minor fixes and improvements Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * more small fixes/improvements; mostly for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CS/DS recipe switch in DPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * avoid quantizing/saving of O when CS bwd uses F16 O Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move fp8_autocast(fp8_recipe) print to utils.py Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add debug logging to unit tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add back prints of quantizers/layer_number for debugging Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * enable amax reduction for both CS and DS tensors Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix NVTE_FP8_DPA_BWD=0 for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit for F16 fwd/bwd a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * small fixes for float8_current_scaling(), nominal types, and unruly d_out types Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8_output in MHA and some CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes to CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes for CP A2A Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * clamp input data in tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove rmse and tighten atol/rtol for tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * restructure fp8_recipes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linter Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "remove rmse and tighten atol/rtol for tests" This reverts commit 15dba6a59a5323d414f02cf22f099cb00d880532. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more fixes for linter Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 recipe changes for F16 code path Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to FE on main to help with merges Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * switch back to FE develop after merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE develop commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to GitHub FE 1.14.1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * update FE to its latest main Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for A2A Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix last commit for A2A DS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove memset for BSHD/SBHD FP8 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove concat for qkv quantization in CS Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * improve/simplify the logic for last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add nominal_type for UnfusedDPA FP8 EmuFunc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: update env vars for DPA recipes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo in last commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix DS recipe creation for NVFP4 global recipe Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace python max with torch.maximum Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linter Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CP A2A for FA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce prints in print_quantizers Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add FP8 env vars to NVTE_DEBUG prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add reduce_amax to DS repr Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * separate fp8_dpa/fp8_mha in CP tests; fix A2A for them; add f16_O tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address some reciews Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make data optional in create_hp_tensor_with_amax Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix for comments in bwd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * print cudnn version in attn tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable CS for Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * alternative tests to reduce CI time Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make NVTE_DPA_FP8CS_O_in_F16 default to 1 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove _fp8 variables to avoid confusion Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * return to requiring two cp_groups for a2a+p2p Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * replace NVTE_PRINT with NVTE_DEBUG/_LEVEL for quantizer prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * provide a basic set of tests for CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the last merge with nvfp4 PR Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * disable for Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fp8 backend selection for Hopper Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * reduce CP CI to essential tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix to CP test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix recipe logic in tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert to concat for qkv quantization Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove cudnn version in qa scripts Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../attention/run_attention_with_cp.py | 151 +- tests/pytorch/attention/test_attention.py | 141 +- .../attention/test_attention_with_cp.py | 91 +- .../fused_attn_f16_arbitrary_seqlen.cu | 8 +- .../common/fused_attn/fused_attn_fp8.cu | 147 +- transformer_engine/common/fused_attn/utils.h | 13 +- transformer_engine/common/recipe/__init__.py | 11 +- .../dot_product_attention/backends.py | 498 ++- .../dot_product_attention/context_parallel.py | 3283 ++++++++--------- .../dot_product_attention.py | 319 +- .../attention/dot_product_attention/utils.py | 221 +- .../pytorch/attention/multi_head_attention.py | 46 +- .../pytorch/cpp_extensions/fused_attn.py | 3 - transformer_engine/pytorch/csrc/common.h | 2 +- transformer_engine/pytorch/csrc/extensions.h | 5 + .../pytorch/csrc/extensions/attention.cpp | 195 +- .../pytorch/csrc/extensions/cast.cpp | 20 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- transformer_engine/pytorch/fp8.py | 4 +- .../pytorch/tensor/float8_tensor.py | 4 + 21 files changed, 2970 insertions(+), 2202 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 1a7b4b78d..80a8e4af4 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 1a7b4b78db44712fb9707d21cd2e3179f1fd88b8 +Subproject commit 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 7e47e7df8..d490c235b 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -12,14 +12,18 @@ from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) +from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer -from transformer_engine.common.recipe import DelayedScaling +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert - dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -151,7 +155,7 @@ def get_tols(config, dtype): elif dtype == "fp8": atol = 5e-1 rtol = 5e-1 - rmse_tol = 0.1 + rmse_tol = 0.15 else: assert False, f"{dtype=} is not supported!" @@ -164,14 +168,23 @@ def run_dpa_with_cp( qkv_format="bshd", kernel_backend="FlashAttention", cp_comm_type="p2p", - fp8_mha=False, + fp8_bwd="True", + fp8_dpa="False", + fp8_mha="False", + scaling_mode="delayed", + f16_O="False", log_level=logging.WARNING, ): """Test DotProductAttention module with context parallelism""" logging.root.setLevel(log_level) # set up environment variables and config - fp8_mha = fp8_mha == "True" + fp8_bwd = fp8_bwd == "True" and dtype == "fp8" + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" + fp8_dpa = fp8_dpa == "True" and dtype == "fp8" + fp8_mha = fp8_mha == "True" and dtype == "fp8" + f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" if kernel_backend == "FlashAttention": @@ -219,8 +232,12 @@ def run_dpa_with_cp( sub_group = dist.new_group(sub_ranks, backend="nccl") if rank in sub_ranks: cp_comm_sub_groups.append(sub_group) + if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) # instantiate attention module core_attn = DotProductAttention( @@ -247,19 +264,38 @@ def run_dpa_with_cp( cu_seqlens_q_padded, cu_seqlens_kv_padded, ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) - q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() - k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() - v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() - for x in [q, k, v]: - x.requires_grad = True - - dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() - if fp8_mha: + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() + if scaling_mode == "delayed": + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) dout_quantizer = Float8Quantizer( fp8_dtype=tex.DType.kFloat8E5M2, scale=torch.tensor([1], dtype=torch.float32).cuda(), amax=torch.tensor([0], dtype=torch.float32).cuda(), ) + if scaling_mode == "current": + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) @@ -274,6 +310,7 @@ def run_dpa_with_cp( else: fp8_context = nullcontext() with fp8_context: + # q, k, v, out in FP8; dout in F16 out = core_attn( q, k, @@ -284,8 +321,9 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - if fp8_mha: + if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: @@ -298,24 +336,10 @@ def run_dpa_with_cp( ############ run with CP ############ logging.info(f"[Rank {rank}] Run with context parallelism") - # set up environment - core_attn.set_context_parallel_group( - cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, - cp_comm_ranks, - torch.cuda.Stream(), - cp_comm_type, - ) - if config.softmax_type != "vanilla": - core_attn.softmax_offset.grad.zero_() - if dtype == "fp8": - core_attn.reset_fp8_meta_tensors() - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) - else: - fp8_context = nullcontext() - # set up inputs q_, k_, v_, dout_, *rest = [ - x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) ] bias_ = rest[0] if len(rest) else None if qkv_format == "bshd" or qkv_format == "sbhd": @@ -343,6 +367,16 @@ def run_dpa_with_cp( ) q_, dout_ = [x.index_select(0, seq_idx_q) for x in [q_, dout_]] k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] + else: + assert False, f"{qkv_format} is an unsupported qkv_format!" + q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( @@ -350,9 +384,25 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + # set up environment + core_attn.set_context_parallel_group( + cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, + cp_comm_ranks, + torch.cuda.Stream(), + cp_comm_type, + ) + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() + if dtype == "fp8": + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + else: + fp8_context = nullcontext() # run attention with fp8_context: + # q, k, v, out in FP8; dout in F16 out_ = core_attn( q_, k_, @@ -363,27 +413,30 @@ def run_dpa_with_cp( cu_seqlens_kv=cu_seqlens_kv, cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_padded, + fp8_output=fp8_mha, ) - if fp8_mha: + if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) - if fp8_mha: - assert isinstance(out, Float8Tensor) - assert isinstance(out_, Float8Tensor) - out = out.dequantize() - out_ = out_.dequantize() - - # get outputs dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad d_softmax_offset_ = None if config.softmax_type != "vanilla": d_softmax_offset_ = core_attn.softmax_offset.grad.clone() - for x in [out_, dq_, dk_, dv_, d_softmax_offset_]: - if x is not None: - assert torch.all(~torch.isnan(x)) - assert torch.all(~torch.isinf(x)) + + # get outputs + tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] + if fp8_mha: + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[4] = tensors_to_deq + for tensor in tensors: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) + out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": @@ -394,17 +447,17 @@ def run_dpa_with_cp( x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [q.grad, k.grad, v.grad, out] + for x in [dq, dk, dv, out] ] dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] dq_, dk_, dv_, out_ = [ x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [q_.grad, k_.grad, v_.grad, out_] + for x in [dq_, dk_, dv_, out_] ] elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a5c345779..e3a4de73b 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1693,23 +1693,44 @@ def get_model(dtype, config): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_mha_fp8_vs_f16( + dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode +): """Test MultiHeadAttention module in FP8""" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + fp8_mha=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_format.replace("hd", "h3d"), + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported - if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: - pytest.skip("Less than two backends to compare.") + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: available_backends, _, fused_attn_backends = get_available_attention_backends( config, @@ -1727,7 +1748,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1735,19 +1756,20 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm, RoPE, is_training + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-1 rmse_tol = 0.15 - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1758,6 +1780,8 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rmse_tol, True, ) + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1784,7 +1808,9 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): +def _run_mha_fp8_vs_f16( + dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe +): """Run MultiHeadAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() @@ -1794,15 +1820,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_mha, - fp8_mha=fp8_mha, - ) - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: @@ -1911,7 +1928,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] @@ -1927,16 +1945,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_layout, + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -1956,32 +1991,44 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe + ) + + if unfused_attn_supported: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, qkv_layout, is_training + dtype, config, False, qkv_layout, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1992,12 +2039,40 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol, True, ) + if unfused_attn_supported: + logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( + unfused_attn_fwd_fp8, + fused_attn_fwd_f16, + "unfused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + True, + ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + compare_and_assert( + unfused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"unfused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + True, + ) if config.dropout_p != 0.0: # test cuDNN FP8 dropout assert torch.all( fused_attn_fwd_fp8 == 1 ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." else: + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, @@ -2021,9 +2096,10 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol, True, ) + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" -def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe): """Run DotProductAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() @@ -2033,14 +2109,6 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_dpa, - ) - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) with fp8_model_init(enabled=fp8_dpa): dpa = DotProductAttention( @@ -2147,6 +2215,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, + fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index c752d07d8..0f00b8b0e 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -14,6 +14,10 @@ get_device_compute_capability, get_cudnn_version, ) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils _current_file = pathlib.Path(__file__).resolve() @@ -27,6 +31,8 @@ torch.manual_seed(seed) torch.cuda.manual_seed(seed) +test_essential = True + model_configs_flash_attn = { # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA @@ -63,12 +69,22 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): return args +dtypes = ["bf16", "fp16"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"] + model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} + dtypes = ["bf16"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): @@ -77,6 +93,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": @@ -162,14 +179,30 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): } +dtypes = ["bf16", "fp16", "fp8"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} + dtypes = ["bf16", "fp8"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) -@pytest.mark.parametrize("fp8_mha", [False, True]) -def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) +@pytest.mark.parametrize("fp8_bwd", [True, False]) +@pytest.mark.parametrize("fp8_mha", [True, False]) +@pytest.mark.parametrize("fp8_dpa", [True, False]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("f16_O", [True, False]) +def test_cp_with_fused_attention( + dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -180,10 +213,15 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") + if dtype == "fp8" and not fp8_dpa and fp8_mha: + pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") + if dtype != "fp8" and fp8_bwd: + pytest.skip("Only fp8 works with fp8_bwd=True!") config = model_configs_fused_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type + if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": @@ -211,8 +249,22 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) - if dtype != "fp8" and fp8_mha: - pytest.skip("Only fp8 works with fp8_mha=True!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("Only fp8 works with scaling_mode != None!") + if dtype == "fp8" and scaling_mode is None: + pytest.skip("fp8 only works with scaling_mode != None!") + if ( + dtype == "fp8" + and scaling_mode == "current" + and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + ): + pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") + if f16_O and (dtype != "fp8" or scaling_mode != "current"): + pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: @@ -229,10 +281,25 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha ) dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + fp8_meta = {} + fp8_meta["recipe"] = None + fp8_meta["local_recipes"] = [] + fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha) + if fp8 and scaling_mode == "delayed": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)] + if fp8 and scaling_mode == "current": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [ + Float8CurrentScaling(fp8_dpa=True), + DelayedScaling(fp8_dpa=True), + ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, qkv_layout="_".join([qkv_format] * 3), + fp8=fp8, + fp8_meta=fp8_meta, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -246,7 +313,11 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha qkv_format=qkv_format, kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, log_level=pytest_logging_level, ), check=True, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 1d6435ad8..ba0f84578 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -129,7 +129,9 @@ void fused_attn_arbitrary_seqlen_fwd_impl( window_size_right, true, tensorType, - tensorType}; + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -585,7 +587,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( window_size_right, deterministic, tensorType, - tensorType}; + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 995dbda7f..21c544491 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); try { FADescriptor_v1 descriptor{b, @@ -1699,8 +1707,10 @@ void fused_attn_fp8_fwd_impl_v1( 0, 0, true, - fwd_tensor_type, - fwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1739,7 +1749,7 @@ void fused_attn_fp8_fwd_impl_v1( // otherwise, build the op_graph and the plan. Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -1787,7 +1797,13 @@ void fused_attn_fp8_fwd_impl_v1( descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() @@ -1839,11 +1855,12 @@ void fused_attn_fp8_fwd_impl_v1( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -1916,13 +1933,16 @@ void fused_attn_fp8_fwd_impl_v1( {descale_v, devPtrDescaleV}, {descale_s, devPtrDescaleS}, {scale_s, devPtrScaleS}, - {scale_o, devPtrScaleO}, {attn_scale, &scaling_factor}, {O, devPtrO}, {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; + if (is_delayed_scaling) { + variant_pack[scale_o] = devPtrScaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; } */ @@ -1963,8 +1983,9 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, + cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1978,6 +1999,15 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -2005,8 +2035,10 @@ void fused_attn_fp8_bwd_impl_v1( 0, 0, false, - fwd_tensor_type, - bwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + do_tensor_type, + dqkv_tensor_type}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2059,7 +2091,7 @@ void fused_attn_fp8_bwd_impl_v1( // otherwise, build the op_graph and the plan. Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -2099,7 +2131,8 @@ void fused_attn_fp8_bwd_impl_v1( o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + .set_stride(o_stride) + .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d}) @@ -2125,14 +2158,26 @@ void fused_attn_fp8_bwd_impl_v1( descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + if (is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() @@ -2214,10 +2259,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); - dO->set_data_type(bwd_tensor_type); - dQ->set_data_type(bwd_tensor_type); - dK->set_data_type(bwd_tensor_type); - dV->set_data_type(bwd_tensor_type); + dO->set_data_type(do_tensor_type); + dQ->set_data_type(dqkv_tensor_type); + dK->set_data_type(dqkv_tensor_type); + dV->set_data_type(dqkv_tensor_type); std::tuple, // q std::shared_ptr, // k @@ -2298,14 +2343,10 @@ void fused_attn_fp8_bwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_o, devPtrDescaleO}, {descale_dO, devPtrDescaledO}, {descale_s, devPtrDescaleS}, {descale_dP, devPtrDescaledP}, {scale_s, devPtrScaleS}, - {scale_dQ, devPtrScaledQ}, - {scale_dK, devPtrScaledK}, - {scale_dV, devPtrScaledV}, {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, @@ -2316,6 +2357,15 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; + if (is_delayed_scaling) { + variant_pack[scale_dQ] = devPtrScaledQ; + variant_pack[scale_dK] = devPtrScaledK; + variant_pack[scale_dV] = devPtrScaledV; + } + if (!is_O_in_F16) { + variant_pack[descale_o] = devPtrDescaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { @@ -2366,6 +2416,7 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; + const DType O_type = output_O->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); size_t stride = 0; @@ -2432,8 +2483,8 @@ void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t ma attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, @@ -2467,6 +2518,7 @@ void fused_attn_fp8_bwd_qkvpacked( cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_QKV->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQKV->data.dtype; void* devPtrQKV = input_QKV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -2484,7 +2536,11 @@ void fused_attn_fp8_bwd_qkvpacked( void* devPtrDescaleV = input_QKV->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2527,7 +2583,8 @@ void fused_attn_fp8_bwd_qkvpacked( devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, @@ -2565,6 +2622,7 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; void* devPtrQ = input_Q->data.dptr; void* devPtrKV = input_KV->data.dptr; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); @@ -2633,8 +2691,8 @@ void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2671,6 +2729,7 @@ void fused_attn_fp8_bwd_kvpacked( cudnnHandle_t handle) { using namespace transformer_engine; const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQ->data.dtype; void* devPtrQ = input_Q->data.dptr; void* devPtrKV = input_KV->data.dptr; @@ -2688,7 +2747,11 @@ void fused_attn_fp8_bwd_kvpacked( void* devPtrDescaleV = input_KV->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2733,7 +2796,8 @@ void fused_attn_fp8_bwd_kvpacked( devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, @@ -2822,6 +2886,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -2831,8 +2896,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2878,7 +2943,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleV = input_Q->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2911,6 +2980,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; @@ -2924,7 +2994,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 0a0197423..f03774f8e 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -111,21 +111,24 @@ struct FADescriptor_v1 { std::int64_t window_size_left; std::int64_t window_size_right; bool deterministic; - cudnn_frontend::DataType_t fwd_tensor_type; - cudnn_frontend::DataType_t bwd_tensor_type; + cudnn_frontend::DataType_t qkv_tensor_type; + cudnn_frontend::DataType_t o_tensor_type; + cudnn_frontend::DataType_t do_tensor_type; + cudnn_frontend::DataType_t dqkv_tensor_type; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, - window_size_left, window_size_right, deterministic, bias_type, fwd_tensor_type, - bwd_tensor_type) < + window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, + o_tensor_type, do_tensor_type, dqkv_tensor_type) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, - rhs.fwd_tensor_type, rhs.bwd_tensor_type); + rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.dqkv_tensor_type); } }; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ea0287ef1..179d618b3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -209,6 +209,7 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " + f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) @@ -226,10 +227,11 @@ class Float8CurrentScaling(Recipe): pass. """ + use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" fp8_format: Format = Format.HYBRID - fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) @@ -238,9 +240,6 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." - assert ( - not self.fp8_dpa and not self.fp8_mha - ), "FP8 attention is not supported for Float8CurrentScaling." def __repr__(self) -> str: return ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4a60bd9fe..f72c1eb9e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -16,14 +16,16 @@ import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - SplitAlongDim, get_device_compute_capability, - combine_tensors, split_tensor_along_dim, ) -from transformer_engine.pytorch.utils import attention_mask_func +from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.tensor.quantized_tensor import ( - QuantizedTensor, + QuantizedTensorBase, prepare_for_saving, restore_from_saved, ) @@ -40,7 +42,7 @@ META_O, META_QKV, ) -from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype +from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( @@ -53,6 +55,9 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -130,6 +135,58 @@ fa_utils.set_flash_attention_3_params() +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + + +class FP8EmulationFunc(torch.autograd.Function): + """ + Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: + - forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through) + - backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through) + """ + + @staticmethod + def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): + # pylint: disable=missing-function-docstring + if quantizer_name == "QKV_quantizer": + query_layer, key_layer, value_layer = [ + x.contiguous() for x in [tensor1, tensor2, tensor3] + ] + q_fp8, k_fp8, v_fp8 = combine_and_quantize( + qkv_layout, query_layer, key_layer, value_layer, quantizer + ) + tensors = combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + ) + elif quantizer_name in ["S_quantizer", "O_quantizer"]: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) + ctx.quantizer = quantizer + ctx.quantizer_name = quantizer_name + ctx.qkv_layout = qkv_layout + return tensors[0], tensors[1], tensors[2] + + @staticmethod + def backward(ctx, grad1, grad2, grad3): + # pylint: disable=missing-function-docstring + if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + elif ctx.quantizer_name == "dQKV_quantizer": + query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] + dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer + ) + tensors = combine_and_dequantize( + ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + ) + else: + tensors = grad1, grad2, grad3 + return tensors[0], tensors[1], tensors[2], None, None, None + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -189,6 +246,10 @@ def forward( alibi_slopes: Optional[torch.Tensor] = None, inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + fp8_output: bool = False, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -286,6 +347,35 @@ def forward( if apply_qk_layer_scaling: scale /= self.layer_number + if fp8: + # get quantizers from DPA; all Nones if not fp8 + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers) + ) + # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if fp8_recipe.float8_current_scaling(): + S_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=S_quantizer.dtype, device="cuda" + ) + dP_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=dP_quantizer.dtype, device="cuda" + ) + + if "2" in qkv_layout or "3" in qkv_layout: + qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) + qkv_layout = "_".join([qkv_format] * 3) + # quantize and dequantize QKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + ) + # quantize and dequantize dQKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + ) + # Raw attention scores. [b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -330,6 +420,12 @@ def forward( dtype=query_layer.dtype ) + if fp8: + # quantize and dequantize dP to emulate FP8 + matmul_result, *_ = FP8EmulationFunc.apply( + matmul_result, None, None, dP_quantizer, "dP_quantizer", None + ) + # add attention sink to the last column: [b, np, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( @@ -379,6 +475,12 @@ def forward( # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + if fp8: + # quantize and dequantize S to emulate FP8 + attention_probs, *_ = FP8EmulationFunc.apply( + attention_probs, None, None, S_quantizer, "S_quantizer", None + ) + # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -413,6 +515,20 @@ def forward( # [tq, np, hn] --> [tq, hp] context_layer = context_layer.view(total_tokens, -1) + if fp8: + # quantize and dequantize O to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, O_quantizer, "O_quantizer", None + ) + # quantize and dequantize dO to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, dO_quantizer, "dO_quantizer", None + ) + + # quantize O + if fp8_output: + context_layer = O_quantizer(context_layer) + return context_layer @@ -511,6 +627,7 @@ def forward( quantizers=None, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), + fp8_output: bool = False, ) -> torch.Tensor: """flash-attn fprop""" @@ -716,6 +833,7 @@ def forward( quantizers=quantizers, pad_between_seqs=False, use_flash_attn_3=use_flash_attn_3, + fp8_output=fp8_output, ) else: from transformer_engine.pytorch.cpu_offload import ( @@ -815,8 +933,6 @@ def convert_to_torch_float8(tensor, dtype): ) return out - # "fp8_mha" decides outputs in fp8, while inputs are inferred from - # the real dtype assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." @@ -863,7 +979,7 @@ def convert_to_torch_float8(tensor, dtype): if fp8: output = output.to(dtype=torch_orig_dtype) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: O_quantizer = quantizers["scaling_fwd"][META_O] output = O_quantizer(output) @@ -891,7 +1007,7 @@ def convert_to_torch_float8(tensor, dtype): if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: output_data = ( output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) .transpose(0, 1) @@ -915,7 +1031,7 @@ def convert_to_torch_float8(tensor, dtype): class FusedAttnFunc(torch.autograd.Function): - """Function for FusedAttention with separate Q, K, V tensors""" + """FusedAttention forward and backward implementation""" @staticmethod def forward( @@ -949,55 +1065,71 @@ def forward( quantizers, deterministic, softmax_offset, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False - - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn - fake_dtype = q.dtype + # add NVTX range + nvtx_label = "transformer_engine.FusedAttnFunc.forward" + nvtx_range_push(f"{nvtx_label}") + + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + + # input types are inferred from the real data while output types are controlled by fp8_output + # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + + # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel in FP8: + is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + + # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + # get nominal data type for out + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + out_nominal_dtype = q.dtype + if fp8: fused_attention_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - q_fp8, k_fp8, v_fp8 = None, None, None + # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - match qkv_group: - case 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = QKV_quantizer(qkv) - q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) - case 2: - q_fp8 = QKV_quantizer(q) - dim = qkv_layout.split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = QKV_quantizer(kv_c) - k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) - case 3: - q_fp8 = QKV_quantizer(q) - k_fp8 = QKV_quantizer(k) - v_fp8 = QKV_quantizer(v) - case _: - raise "Invalid qkv_layout " + qkv_layout - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - out_fp8, aux_ctx_tensors = fused_attn_fwd( + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1006,7 +1138,7 @@ def forward( q_fp8, k_fp8, v_fp8, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1026,42 +1158,54 @@ def forward( rng_gen, softmax_offset, ) - if is_output_fp8: - out_ret = out_fp8 + + # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_fp8 = out_ + out = out_ + + if isinstance(out_, Float8Tensor): + if not is_output_fp8 or not is_bwd_fp8: + out = out_.dequantize().view(out_.shape) else: - out_ret = out_fp8.dequantize().view(out_fp8.shape) - # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 - # is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn - out_save = out_ret + if is_output_fp8 or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_fp8 = O_quantizer(out_) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - # 1: qkv packed, 2: kv packed, 3: qkv separate + # return appropriate tensors + out_ret = out_fp8 if is_output_fp8 else out + + # save appropriate tensors + fp8_tensors = (None, None, None, None) + qkvo_tensors = (None, None, None, None) + if is_bwd_fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + qkvo_tensors = (None, None, None, out) + else: + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + else: if is_input_fp8: - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) - q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) - if qkv_group == 2: - q = q.dequantize() - dim = qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = kv.dequantize() - k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) - if qkv_group == 3: - q = q.dequantize() - k = k.dequantize() - v = v.dequantize() - if is_output_fp8: - out_save = out_fp8.dequantize() - - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + qkvo_tensors = (q, k, v, out) else: - # q, k, v, out_ret: torch.float16 or torch.bfloat16 - out_ret, aux_ctx_tensors = fused_attn_fwd( + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1070,7 +1214,7 @@ def forward( q, k, v, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1090,10 +1234,18 @@ def forward( rng_gen, softmax_offset, ) - out_save = out_ret + out = out_ + out_ret = out_ fp8_tensors = (None, None, None, None) + qkvo_tensors = (q, k, v, out) + + nvtx_range_pop(f"{nvtx_label}") - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = is_bwd_fp8 + # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + ctx.nominal_dtype = out_nominal_dtype from transformer_engine.pytorch.cpu_offload import ( CPUOffloadEnabled, @@ -1104,7 +1256,7 @@ def forward( if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out_save] + tensor_list = [q, k, v, out] qkv_layout = "sbhd_sbhd_sbhd" mark_activation_offload(*tensor_list) @@ -1112,7 +1264,6 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *qkvo_tensors, @@ -1126,11 +1277,14 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta + ctx.layer_number = layer_number + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.S_quantizer = S_quantizer - if ctx.fp8: + if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer): ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() @@ -1155,17 +1309,15 @@ def forward( @staticmethod def backward(ctx, d_out): # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2 - fake_dtype = d_out.dtype - d_out = d_out.contiguous() + # d_out is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase): + d_out = ctx.dO_quantizer(d_out) + if not ctx.use_FAv2_bwd: + d_out._data = d_out._data.contiguous() + elif not ctx.use_FAv2_bwd: + d_out = d_out.contiguous() ( q_fp8, k_fp8, @@ -1219,16 +1371,55 @@ def backward(ctx, d_out): dk = dk[..., : d_out.shape[-1]] dv = dv[..., : d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn"): + with torch.cuda.nvtx.range("FusedAttnFunc.backward"): + # get nominal data type of dq, dk, dv + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + dqkv_nominal_dtype = ctx.nominal_dtype + if ctx.fp8: + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 if ctx.is_output_fp8: d_out_fp8 = d_out else: d_out_fp8 = ctx.dO_quantizer(d_out) - dqkv_dtype = TE_DType[d_out_fp8._data.dtype] - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 - dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # get tex.DType for dq, dk, dv data + dqkv_te_dtype = d_out_fp8._fp8_dtype + + # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # fp8_dtype = tex.DType.kFloat8E4M3 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # + # dq_, dk_, dv_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ) + dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -1236,10 +1427,10 @@ def backward(ctx, d_out): q_fp8, k_fp8, v_fp8, - out_fp8, + out_, d_out_fp8, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1258,40 +1449,40 @@ def backward(ctx, d_out): ctx.deterministic, ) - # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 - # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 - if not ctx.is_input_fp8: - qkv_group = len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = ctx.qkv_layout.find("3") - dqkv_fp8_data = combine_tensors( - [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim - ) - dqkv_fp8 = dq_fp8.make_like( - tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape - ) - dqkv = dqkv_fp8.dequantize() - dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) - if qkv_group == 2: - dq = dq_fp8.dequantize() - dim = ctx.qkv_layout.split("_")[1].find("2") - dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] - ) - dkv = dkv_c_fp8.dequantize() - dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True) - if qkv_group == 3: - dq = dq_fp8.dequantize() - dk = dk_fp8.dequantize() - dv = dv_fp8.dequantize() - else: - dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 + # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + dq, dk, dv = dq_, dk_, dv_ + is_float8tensor = isinstance(dq_, Float8Tensor) + if is_float8tensor and not ctx.is_input_fp8: + # return in F16 + dq, dk, dv = combine_and_dequantize( + ctx.qkv_layout, + dq_, + dk_, + dv_, + src_nominal_dtype=dq_.dtype, + ) + if not is_float8tensor and ctx.is_input_fp8: + # return in FP8 + dq, dk, dv = combine_and_quantize( + ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + ) + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) else: - if isinstance(d_out, QuantizedTensor): - d_out = d_out.dequantize() - dqkv_dtype = TE_DType[d_out.dtype] - # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 + if isinstance(d_out, QuantizedTensorBase): + d_out = d_out.dequantize(dtype=ctx.nominal_dtype) + dqkv_te_dtype = TE_DType[d_out.dtype] + # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1302,8 +1493,8 @@ def backward(ctx, d_out): v, out, d_out, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1358,6 +1549,8 @@ def backward(ctx, d_out): None, None, d_softmax_offset, + None, + None, ) @@ -1463,6 +1656,7 @@ def forward( pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, softmax_offset: torch.Tensor = None, + fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -1563,15 +1757,27 @@ def forward( ) if fp8: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" " is required for FP8 attention!" ) assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" - assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across TP+CP group is necessary when using context parallelism with" - " FP8!" - ) + if fp8_recipe.delayed(): + assert not context_parallel or fp8_recipe.reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism" + " with FP8!" + ) + if fp8_recipe.float8_current_scaling() and context_parallel: + all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + for q in all_quantizers: + if isinstance(q, Float8CurrentScalingQuantizer): + q.with_amax_reduction = True + q.amax_reduction_group = ( + cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group + ) if context_parallel: assert ( @@ -1615,6 +1821,8 @@ def forward( pad_between_seqs=pad_between_seqs, softmax_type=self.softmax_type, softmax_offset=softmax_offset, + fp8_output=fp8_output, + layer_number=self.layer_number, ) else: with self.attention_dropout_ctx(): @@ -1648,6 +1856,8 @@ def forward( quantizers, self.deterministic, softmax_offset, + fp8_output, + self.layer_number, ) # ...hd -> ...(hd) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 2e4b6b617..539caffbb 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -9,7 +9,6 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - combine_tensors, get_cudnn_version, nvtx_range_pop, nvtx_range_push, @@ -20,7 +19,9 @@ fused_attn_bwd, FusedAttnBackend, ) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( dist_group_type, @@ -41,6 +42,9 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) _cu_seqlens_info_with_cp_cache = {} @@ -48,6 +52,9 @@ _seq_chunk_ids_cache_for_reordering_after_attn = {} _softmax_offset_chunk_ids_cache = {} +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + def flash_attn_p2p_communicate( rank, send_tensor, send_dst, recv_tensor, recv_src, cp_group, batch_p2p_comm @@ -226,11 +233,11 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): @jit_fuser def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -240,13 +247,13 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz @jit_fuser def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -278,16 +285,16 @@ def flash_attn_a2a_communicate( x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # or [s, b, h, d] -> [s, b, cp, h//cp, d] x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] a2a_inputs[i] = x.movedim(-3, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -298,8 +305,8 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -309,11 +316,11 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs @@ -467,6 +474,585 @@ def get_fa_args( ] +def cp_p2p_fwd_prepare_qkv( + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + step, + cp_size, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P forward""" + cu_seqlens_q_per_step = None + cu_seqlens_kv_per_step = None + if section in ["diagonal", "all"]: + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + rank_ = rank if section == "diagonal" else (rank - step) % cp_size + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank_, True, True + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, k_part, v_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part]] + + elif section == "lower-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + False, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv_half + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part = q_part.view(q_part.shape[0], -1, *q_part.shape[-2:]) + # [b, 2, sk//2, h, d] -> [b, sk//2, h, d] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part = q_part.view(-1, *q_part.shape[-3:]) + # [2, sk//2, b, h, d] -> [sk//2, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + + elif section == "upper-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // (cp_size * 2) + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q_half + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part = q_part[:, 1, ...] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part = q_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part = tex.thd_read_half_tensor(q_part, cu_seqlens_q_padded, 1) + + return q_part, k_part, v_part, cu_seqlens_q_per_step, cu_seqlens_kv_per_step + + +def cp_p2p_fwd_fused_attn( + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step, + O_quantizer_per_step, + rank, + step, + cp_size, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FusedAttention backend""" + attn_bias_inputs = None + max_seqlen_q_ = None + max_seqlen_kv_ = None + cu_seqlens_q_ = None + cu_seqlens_kv_ = None + attn_mask_type_ = None + cu_seqlens_q_padded_ = None + cu_seqlens_kv_padded_ = None + if section in ["diagonal", "all"]: + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + elif section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = attn_bias[..., idx, :].contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = ( + cu_seqlens_kv_padded // 2 if cu_seqlens_kv_padded is not None else None + ) + elif section == "upper-triangle": + q_part = q_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q // 2 + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded // 2 if cu_seqlens_q_padded is not None else None + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step + fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step + + out_per_step, aux_ctx_tensors = fused_attn_fwd( + is_training, + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_, + cu_seqlens_kv_, + q_part, + k_part, + v_part, + fake_dtype=fwd_nominal_dtype, + fused_attention_backend=fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + **fp8_meta_kwargs, + ) + + if fp8: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors + attn_bias = rest[0] if len(rest) > 0 else None + + return out_per_step, softmax_lse_per_step, rng_states, attn_bias + + +def cp_p2p_fwd_flash_attn( + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FlashAttention backend""" + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + causal_ = False + if section in ["diagonal", "all"]: + causal_ = section == "diagonal" + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + if section in ["lower-triangle", "upper-triangle"]: + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_, + cu_seqlens_kv=cu_seqlens_kv_, + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + ) + fa_outputs = flash_attn_fwd( + q_part, + k_part, + v_part, + *fa_forward_args_thd, + causal=causal_, + **fa_forward_kwargs, + ) + rng_states = None + if not fa_utils.v2_7_0_plus: + out_per_step = fa_outputs[4] + softmax_lse_per_step = fa_outputs[5] + if not use_flash_attn_3: + rng_states = fa_outputs[7] + else: + out_per_step = fa_outputs[0] + softmax_lse_per_step = fa_outputs[1] + if not use_flash_attn_3: + rng_states = fa_outputs[3] + + return out_per_step, softmax_lse_per_step, rng_states + + +def cp_p2p_bwd_prepare_qkv( + q_part, + k_part, + v_part, + out_part, + dout_part, + qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P backward""" + if section in ["diagonal", "all"]: + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) + for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif section == "lower-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, out_part, dout_part] + ] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, out_part, dout_part] + ] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + elif section == "upper-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part, out_part, dout_part = q_part[:, 1], out_part[:, 1], dout_part[:, 1] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part, out_part, dout_part = q_part[1], out_part[1], dout_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part, out_part, dout_part = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q_part, out_part, dout_part] + ] + + return q_part, k_part, v_part, out_part, dout_part + + +def cp_p2p_bwd_fused_attn( + fp8, + fp8_recipe, + q_fp8, + kv_fp8, + out_fp8, + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + max_seqlen_q, + max_seqlen_kv, + step, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + deterministic, + fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + S_quantizer, + dP_quantizer_per_step, + dQKV_quantizer_per_step, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FusedAttention backend""" + if fp8: + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + attn_mask_type_ = attn_mask_type + + if section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_kv_padded_ = None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + elif section == "upper-triangle": + q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] + if fp8: + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q // 2 + cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + + if attn_dbias is not None: + aux_tensors += [attn_biases[cp_size - step - 1]] + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step + + dq, dk, dv, dbias, *_ = fused_attn_bwd( + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv_per_step[cp_size - step - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + bwd_nominal_dtype, + bwd_output_te_dtype, + aux_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + deterministic=deterministic, + **fp8_meta_kwargs, + ) + + return dq, dk, dv, dbias + + +def cp_p2p_bwd_flash_attn( + use_flash_attn_3, + qkv_format, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + step, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FlashAttention backend""" + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - step - 1] + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + softmax_lse__ = softmax_lse + causal_ = False + if section == "diagonal": + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, 0) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + causal_ = True + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + softmax_lse__ = softmax_lse_ + + fa_backward_args_thd = get_fa_args( + False, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + dq=dq, + dk=dk, + dv=dv, + ) + flash_attn_bwd( + dout_part, + q_part, + k_part, + v_part, + out_part, + softmax_lse__, + *fa_backward_args_thd, + causal=causal_, + **fa_backward_kwargs, + ) + + return dq, dk, dv + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. Exchange KV between CP ranks @@ -508,30 +1094,24 @@ def forward( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") - enable_mla = k.shape[-1] != v.shape[-1] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.forward" + nvtx_range_push(f"{nvtx_label}") + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 if isinstance(cp_group, list): - assert ( - qkv_format != "thd" - ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" - assert attn_bias_type == "no_bias", ( - f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" - " yet!" - ) cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) cp_group = cp_group[1] - else: - cp_group_a2a = None - cp_size_a2a = 1 - rank_a2a = 0 - cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] @@ -541,18 +1121,19 @@ def forward( device_compute_capability < (10, 0) and cp_size == 2 ) + # set up attention args + enable_mla = k.shape[-1] != v.shape[-1] causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") - if enable_mla: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - else: - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None if use_fused_attention: batch_dim = qkv_format.index("b") @@ -563,7 +1144,6 @@ def forward( q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv ) else: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size @@ -573,79 +1153,110 @@ def forward( cu_seqlens_kv_per_step = [None for _ in range(cp_size)] fused_attn_backend = None - qkv_dtype = q.dtype amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] - O_CP_quantizer_per_step = [None for _ in range(cp_size)] - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False + O_quantizer_per_step = [None for _ in range(cp_size)] + + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + fwd_nominal_dtype = q.dtype + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] ( QKV_quantizer, O_quantizer, - O_CP_quantizer, S_quantizer, dQKV_quantizer, - dQKV_CP_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + q_f16 = None + q_fp8, k_fp8, v_fp8 = (None, None, None) + # communicate for the 'a2a' part of 'a2a+p2p' + if cp_size_a2a > 1: + if fp8 and is_input_fp8: + QKV_quantizer = q._quantizer + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = (q._data, k._data, v._data) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if fp8 and is_input_fp8: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) + ] + q, k, v = q_fp8, k_fp8, v_fp8 + + # convert qkv to the right type if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha - if is_input_fp8: - QKV_quantizer = q._quantizer - q, k, v = q._data, k._data, v._data - else: - q_f16, k_f16, v_f16 = q, k, v - if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16)._data - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - # partial result quantizer - for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() - O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if is_input_fp8: + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] else: - assert False, "FP8 is only supported with Fused Attention!" + # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_f16 = q + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # amax_per_step[0]: amax_s x cp_size + # amax_per_step[1]: amax_o x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i] = O_quantizer.copy() + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: + # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] - if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) - - q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True - ) - if not fp8: - q_f16 = q - elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16 = q - q = QKV_quantizer(q_f16)._data - + # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" if causal: if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] + # [s, b, h, d] -> [2, s//2, b, h, d] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -654,7 +1265,7 @@ def forward( assert ( attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 ), "Sequence length does not meet divisible requirements!" - # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-2], 2, @@ -662,12 +1273,14 @@ def forward( 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size), ) - # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" + # stats tensor shape: + # BHS1 before cuDNN 9.6 or flash-attention v2.6/v3 + # TH1 after cuDNN 9.6 or flash-attention v2.6/v3 softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: @@ -675,7 +1288,9 @@ def forward( else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + # set up args for FlashAttention backend flash_attn_fwd = None + fa_forward_kwargs = {} if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_flash_attn_3: @@ -714,11 +1329,9 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - # Flash Attn inputs + # set up inputs for forward q_inputs = [None, None] kv_inputs = [None, None] - attn_bias_inputs = [None, None] - # Flash Attn outputs out_per_step = [None for _ in range(cp_size)] softmax_lse_per_step = [None for _ in range(cp_size)] rng_states = [None for _ in range(cp_size)] @@ -730,19 +1343,15 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if enable_mla: - # If MLA, the shape of k and v does not match, so we flatten them - # and split them after receiving them. - k_shape = k.shape - k_numel = k.numel() - v_shape = v.shape - p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) - elif qkv_format in ["bshd", "sbhd"]: - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) - else: # qkv_format == "thd" - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] + # P2P communication and compute: each rank has cp_size steps + # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None for i in range(cp_size + 1): if i < cp_size: @@ -763,634 +1372,205 @@ def forward( batch_p2p_comm, ) - if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - kv_inputs[i % 2] = p2p_comm_buffers[i] + kv_inputs[i % 2] = p2p_comm_buffers[i] + k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + q_part = q + + prepare_inputs = [ + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + i, + cp_size, + ] + if use_fused_attention: + fused_attn_inputs = [ + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step[i], + O_quantizer_per_step[i], + rank, + i, + cp_size, + ] else: - # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data - if enable_mla: - # If MLA, k and v are flattened, so split them after receiving. - k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) - v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + flash_attn_inputs = [ + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + ] + + # cp_size = 4: + # + # step + # section | 0 1 2 3 + # -------------------- + # G 0 | d, u, u, u, + # P 1 | l, d, u, u, + # U 2 | l, l, d, u, + # 3 | l, l, l, d, + # + # Each GPU holds a slice of Q and KV. To compute the attention of each Q slice, each GPU + # runs cp_size steps to get the partial results of its own Q and all KV slices. KV is communicated + # in a point-to-point, ring fashion. For attn_mask_type = causal, there are three attention + # patterns in the cp_size x cp_size (i.e. GPU x step) matrix, the diagonal tiles, the lower-triangle + # tiles, and the upper-triangle tiles. For attn_mask_type != causal, the pattern is all the same. if causal: if i == 0: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q + section = "diagonal" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - fake_dtype=qkv_dtype, - fused_attention_backend=fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=True, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] elif i <= rank: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - False, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] - k_part = k_part[:, 0, ...] - v_part = v_part[:, 0, ...] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0] - elif qkv_format == "thd": - q_inputs[i % 2] = q - if enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor( - k_part, cu_seqlens_kv_padded, 0 - ) - v_part = tex.thd_read_half_tensor( - v_part, cu_seqlens_kv_padded, 0 - ) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + section = "lower-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv // 2, ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, - ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q_half - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...] - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1] - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + section = "upper-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - q_inputs[i % 2] = q_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias_[..., 1, :, idx, :], - attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q // 2, - max_seqlen_kv=max_seqlen_kv, ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, - ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv + # all tiles + section = "all" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, - ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q, - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] + # softmax_lse correction if i > 0: - # wait until fwd restuls correction of last step is done + # wait until fwd results correction of last step is done if i > 1: flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] or - # [t, np, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] or + # [t, h, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() ) if fp8: - out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) + # dequantize out_per_step to torch.float32 + if fp8_recipe.delayed(): + out_per_step[i - 1] = out_per_step[i - 1].dequantize( + dtype=torch.float32 + ) + if fp8_recipe.float8_current_scaling(): + out_per_step[i - 1] = out_per_step[i - 1].to(dtype=torch.float32) + if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": @@ -1430,6 +1610,7 @@ def forward( if causal and rank < (cp_size - 1): second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: @@ -1482,7 +1663,6 @@ def forward( softmax_lse_in_packed_format, ) - kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] @@ -1497,39 +1677,84 @@ def forward( ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] + # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] + # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) + # update FP8 quantizers: amax across cp_size steps if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) - O_CP_quantizer.amax.copy_(amax_cp_fwd[1]) + O_quantizer.amax.copy_(amax_cp_fwd[1]) - out_fp8 = None - out_f16 = out.to(qkv_dtype) - - if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = O_quantizer(out_f16) # final result + if fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + # prepare for return and ctx saves + out_fp8 = None + out_f16 = out.to(fwd_nominal_dtype) + if fp8 and ( + is_output_fp8 + or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + ): + out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 - if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8._data + ctx.layer_number = layer_number + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = fp8 and is_bwd_fp8 + + kv_fp8 = None + kv = p2p_comm_buffers[-1] + if fp8: + q_fp8, kv_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8], [q, kv]) + ] + # q, kv, out + fp8_tensors = (None, None, None) + f16_tensors = (None, None, None) + if ctx.fp8: + # fwd: fp8, bwd: fp8, save all fp8 + fp8_tensors = (q_fp8, kv_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + f16_tensors = (None, None, out_f16) elif fp8 and is_input_fp8: - q_save, kv_save, out_save = q, kv, out_f16 + # fwd: fp8, bwd: f16, save all f16 + # dequantize fp8 inputs + q_f16 = q_fp8.dequantize() + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8: + # fwd: fp8, bwd: f16, save all f16 + # inputs are already in f16 + q_f16 = q_f16.view(q.shape) + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) else: + # fwd: f16, bwd: f16, save all f16 + # inputs and kernels are both f16 q_f16 = q_f16.view(q.shape) - q_save, kv_save, out_save = q_f16, kv, out_f16 + kv_f16 = kv + f16_tensors = (q_f16, kv_f16, out_f16) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - kv_save, - out_save, + *fp8_tensors, + *f16_tensors, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, @@ -1559,21 +1784,18 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 ctx.enable_mla = enable_mla - if enable_mla: - ctx.k_numel = k_numel - ctx.k_shape = k_shape - ctx.v_shape = v_shape + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer - ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.QKV_quantizer = QKV_quantizer @@ -1586,17 +1808,31 @@ def forward( ctx.O_quantizer.scale = O_quantizer.scale.clone() ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + + nvtx_range_pop(f"{nvtx_label}") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.backward" + nvtx_range_push(f"{nvtx_label}") + + # dout is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase): + dout = ctx.dO_quantizer(dout) + if ctx.use_fused_attention: + dout._data = dout._data.contiguous() + elif ctx.use_fused_attention: + dout = dout.contiguous() + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a - cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] @@ -1606,33 +1842,38 @@ def backward(ctx, dout): device_compute_capability < (10, 0) and cp_size == 2 ) - q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( - restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - ) + # get saved tensors + ( + q_fp8, + kv_fp8, + out_fp8, + q, + kv, + out, + softmax_lse, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] rng_states = other_tensors[cp_size * 2 : cp_size * 3] attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + # set up attention args causal = "causal" in ctx.attn_mask_type - padding = "padding" in ctx.attn_mask_type - seq_dim = None + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") - if ctx.enable_mla: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + # set up attention bias if attn_biases[0] is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( *attn_dbias.shape[:-3], 2, attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] ) @@ -1640,6 +1881,7 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None + # set up softmax_lse softmax_lse_ = None if causal and ctx.second_half_lse_seqlen is not None: if ctx.qkv_format == "thd": @@ -1650,86 +1892,124 @@ def backward(ctx, dout): ctx.second_half_lse_seqlen, ) else: - # [b, np, sq] -> [b, np, 2, sq//2] + # [b, h, sq] -> [b, h, 2, sq//2] softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() - # [b, np, sq//2] -> [b, np, sq//2, 1] or - # [t//2, np] -> [t//2, np, 1] + # [b, h, sq//2] -> [b, h, sq//2, 1] or + # [t//2, np] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() - # [b, np, sq] -> [b, np, sq, 1] or - # [t, np] -> [t, np, 1] + # [b, h, sq] -> [b, h, sq, 1] or + # [t, np] -> [t, h, 1] softmax_lse.unsqueeze_(-1) - dout = dout.contiguous() - dq = None - dout_dtype = dout.dtype + # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + bwd_nominal_dtype = ctx.fwd_nominal_dtype + + # convert out, dout to the right type fused_attn_backend = None - fused_attn_dqkv_dtype = None amax_per_step = None dP_quantizer_per_step = [None for _ in range(cp_size)] - dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_quantizer_per_step = [None for _ in range(cp_size)] + buffer_dtype = torch.uint8 + dq_buffer = None + dout_fp8 = None + bwd_output_te_dtype = None + dkv_buffer = None if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: - dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device) - dkv_fp8 = torch.empty( - (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device - ) - dkv_fp8_ = torch.empty_like(dkv_fp8) - p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] - dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() - dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype + # dout: torch.Tensor, dtype=torch.uint8 + if ctx.is_output_fp8: + dout_fp8 = dout else: - assert False, "FP8 is only supported with Fused Attention!" + dout_fp8 = ctx.dO_quantizer(dout) + dout = dout_fp8._data + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # dout_fp8._fp8_dtype + bwd_output_te_dtype = ctx.dO_quantizer.dtype + + # create buffers for reduction in float32 + if ctx.fp8_recipe.delayed(): + dq_buffer = torch.empty( + (cp_size, *q.shape), + dtype=buffer_dtype, + device=q.device, + ) + if ctx.fp8_recipe.float8_current_scaling(): + dq_buffer = torch.empty( + q.shape, + dtype=torch.float32, + device=q.device, + ) + kv_recv_buffer = torch.empty_like(kv) + dkv_send_buffer = torch.empty( + (cp_size, *kv.shape), + dtype=buffer_dtype, + device=kv.device, + ) + dkv_recv_buffer = torch.empty_like(dkv_send_buffer) + p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] + if ctx.fp8_recipe.float8_current_scaling(): + dkv_buffer = torch.zeros( + kv.shape, + dtype=torch.float32, + device=kv.device, + ) + + # amax_per_step[0]: amax_dp x cp_size + # amax_per_step[1]: amax_dqkv x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: - if ctx.fp8_meta is not None: - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - kv = ctx.QKV_quantizer.create_tensor_from_data( - kv, fake_dtype=ctx.qkv_dtype, internal=True - ) - q = q.dequantize(dtype=ctx.qkv_dtype) - kv = kv.dequantize(dtype=ctx.qkv_dtype) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - if cp_size_a2a == 1: - dout = dout.dequantize(dtype=dout_dtype) - else: - ctx.dO_quantizer = dout._quantizer - dout = dout._data - dq = torch.empty_like(q) + if isinstance(dout, QuantizedTensorBase): + dout = dout.dequantize(dtype=bwd_nominal_dtype) + dq_buffer = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) @@ -1746,11 +2026,6 @@ def backward(ctx, dout): ctx.cp_stream, True, ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - dout = dout.dequantize(dtype=dout_dtype) if ctx.enable_mla: out = out.view(*ctx.v_shape) @@ -1759,7 +2034,6 @@ def backward(ctx, dout): # MHA or GQA out = out.view(*q.shape) dout = dout.view(*q.shape) - send_recv_reqs = [] flash_attn_bwd = None if not ctx.use_fused_attention: @@ -1794,6 +2068,7 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + send_recv_reqs = [] for i in range(cp_size): # wait until KV is received for req in send_recv_reqs: @@ -1814,8 +2089,8 @@ def backward(ctx, dout): ) else: dkv_a2a_req = torch.distributed.all_to_all_single( - dkv_fp8, - dkv_fp8_, + dkv_send_buffer, + dkv_recv_buffer, group=ctx.cp_group, async_op=True, ) @@ -1832,593 +2107,146 @@ def backward(ctx, dout): ) kv = p2p_comm_buffers[i % 2][0] - q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None - if ctx.enable_mla: - k_part = kv[: ctx.k_numel].view(*ctx.k_shape) - v_part = kv[ctx.k_numel :].view(*ctx.v_shape) - # In reversed order of fwd - if causal: - if i == (cp_size - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout - if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) + q_part, out_part, dout_part = q, out, dout - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, - ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, 0) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = 0 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=True, - **fa_backward_kwargs, - ) - elif i >= (cp_size - rank - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part[:, 0] - v_part = v_part[:, 0] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0] - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0] - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - if ctx.enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) - v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - if ctx.use_fused_attention: - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_ = kv_.contiguous() - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ + prepare_inputs = [ + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ] + if ctx.use_fused_attention: + fused_attn_inputs = [ + ctx.fp8, + ctx.fp8_recipe, + q_fp8, + kv_fp8, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ), + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + i, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + ctx.softmax_scale, + ctx.dropout_p, + qkv_layout, + ctx.attn_mask_type, + ctx.attn_bias_type, + ctx.deterministic, + ctx.fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + ctx.S_quantizer, + dP_quantizer_per_step[i], + dQKV_quantizer_per_step[i], + ] + else: + flash_attn_inputs = [ + ctx.use_flash_attn_3, + ctx.qkv_format, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + i, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + ] - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 - ), - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, + # there are still three sections in these tiles based on their attention pattern + # for attn_mask_type = causal, and one for attn_mask_type != causal. + if causal: + if i == (cp_size - 1): + section = "diagonal" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv // 2, - dq=dq_, - dk=dk_, - dv=dv_, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + elif i >= (cp_size - rank - 1): + section = "lower-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section + ) + else: + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_, out_, dout_ = q[1], out[1], dout[1] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_, out_, dout_ = [ - tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) - for x in [q, out, dout] - ] - kv_ = kv + section = "upper-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q // 2, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse_, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: + section = "all" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q - if not ctx.enable_mla: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - out_part = out - dout_part = dout - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_, *_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - dkv_ = torch.empty_like(kv) - dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] - dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout, - q, - k_part, - v_part, - out, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq = dq_fp8[(rank + i + 1) % cp_size] + # dq, dk, dv are reduced across steps in higher precision + # DelayedScaling: collect all results in uint8 to one tensor, dequantize to float32, then reduce + # CurrentScaling: dequantize partial results from each step to float32, then reduce + if ctx.fp8 and ctx.use_fused_attention: + if ctx.fp8_recipe.delayed(): + dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] + if ctx.fp8_recipe.float8_current_scaling(): + dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] + + # copy dq_ into the right buffer position + # buffer is cp_size x dq_size for DelayedScaling and the same size as dq for CurrentScaling + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dq = dq_buffer[(rank + i + 1) % cp_size] + else: + dq = dq_buffer if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): - # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or - # [sq, b, np, hn] -> [2, sq//2, b, np, hn] + # [b, sq, h, d] -> [b, 2, sq//2, h, d] or + # [sq, b, h, d] -> [2, sq//2, b, h, d] dq_ = dq_.view(*dq.shape) - - if ctx.fp8: + if ctx.fp8 and ctx.fp8_recipe.delayed(): if i >= (cp_size - rank - 1) or not causal: dq.copy_(dq_) else: @@ -2428,6 +2256,8 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq[0].fill_(0) dq[1].copy_(dq_) + else: + dq.copy_(dq_) elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) @@ -2463,18 +2293,19 @@ def backward(ctx, dout): else: dq.add_(dq_) + # dbias correction if attn_dbias is not None: idx = (rank + i + 1) % cp_size if i == (cp_size - 1) or not causal: - # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + # [b, h, sq, sk//cp] -> [b, h, sq, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) elif i >= (cp_size - rank - 1): - # [b, np, sq, sk//(2*cp)] + # [b, h, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) else: - # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + # [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) @@ -2483,254 +2314,159 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - if ctx.fp8: - if i < cp_size - 1: - dkv = dkv_fp8_[(rank + i + 1) % cp_size] - else: - dkv = dkv_fp8[(rank + i + 1) % cp_size] + # dkv correction + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] + elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] - if ctx.use_fused_attention: - if ctx.enable_mla: - dkv_ = None - elif ctx.qkv_format in ["bshd", "sbhd"]: - dkv_ = combine_tensors([dk_, dv_], -2) - elif ctx.qkv_format == "thd": - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment - if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. - dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - dkv_ = dkv_.movedim(-3, 0) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv_ = dkv_.view(*dkv.shape) - - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] or - # [2, sk//2, b, np, hn] - dk = dkv[: ctx.k_numel].view(*ctx.k_shape) - dv = dkv[ctx.k_numel :].view(*ctx.v_shape) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - dk_ = dk_.view(*ctx.k_shape) - dv_ = dv_.view(*ctx.v_shape) - - if ctx.fp8: - # enable_mla and fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dk[:, 1, ...].fill_(0) - dv[:, 0, ...].copy_(dv_) - dv[:, 1, ...].fill_(0) - elif ctx.qkv_format == "sbhd": - dk[0].copy_(dk_) - dk[1].fill_(0) - dv[0].copy_(dv_) - dv[1].fill_(0) - else: - dk.copy_(dk_) - dv.copy_(dv_) - elif causal: - # enable_mla and not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_[:, 0, ...]) - dk[:, 1, ...].copy_(dk_[:, 1, ...]) - dv[:, 0, ...].add_(dv_[:, 0, ...]) - dv[:, 1, ...].copy_(dv_[:, 1, ...]) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_[0, ...]) - dk[1, ...].copy_(dk_[1, ...]) - dv[0, ...].add_(dv_[0, ...]) - dv[1, ...].copy_(dv_[1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "copy" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dk.add_(dk_) - dv.add_(dv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dv[:, 0, ...].copy_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].copy_(dk_) - dv[0, ...].copy_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "copy", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_) - dv[:, 0, ...].add_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_) - dv[0, ...].add_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dk.add_(dk_) - dv.add_(dv_) - else: # i == 0 + + # [b, 2, sk//2, h, d] or + # [2, sk//2, b, h, d] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8 and ctx.fp8_recipe.delayed(): + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: dk.copy_(dk_) dv.copy_(dv_) else: - # enable_mla and not fp8 and not causal - if i == 0: - dk.copy_(dk_) - dv.copy_(dv_) - else: # i > 0 + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "copy") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "copy") + else: dk.add_(dk_) dv.add_(dv_) - else: - if ctx.fp8: - # fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - dkv[:, :, 1, ...].fill_(0) + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - dkv[:, 1, ...].fill_(0) + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "copy", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "copy", "none") else: - dkv.copy_(dkv_) - elif causal: - # not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dkv.add_(dkv_) - else: # i == 0 - dkv.copy_(dkv_) - else: - # not fp8 and not causal - if i == 0: - dkv.copy_(dkv_) - else: # i > 0 - dkv.add_(dkv_) + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "none") + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) - dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dq_fp8, fake_dtype=torch.float32, internal=True - ) - - if ctx.enable_mla: - # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] - dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) - dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) - dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dk_fp8, fake_dtype=torch.float32, internal=True - ) - dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dv_fp8, fake_dtype=torch.float32, internal=True - ) - dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] - dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] - else: - if ctx.qkv_format in ["bshd", "sbhd"]: - # [cp, b, 2, sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or - # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] - dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dkv_fp8, fake_dtype=torch.float32, internal=True + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + + dq = dq_buffer + if ctx.fp8_recipe.delayed(): + # [cp, b, 2, sk//2, h, d] or [cp, 2, sk//2, b, h, d] + dk = dkv_recv_buffer[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv = dkv_recv_buffer[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dq, dk, dv = [ + ctx.dQKV_quantizer.create_tensor_from_data( + x, fake_dtype=bwd_nominal_dtype, internal=ctx.dQKV_quantizer.internal + ) + for x in [dq, dk, dv] + ] + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + des_nominal_dtype=torch.float32, ) - dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] - dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if causal: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) - dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - dq = dq.view(-1, *dq.shape[-3:]) - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - else: - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + if ctx.fp8_recipe.float8_current_scaling(): + dk = dkv[: ctx.k_numel].view(ctx.k_shape) + dv = dkv[ctx.k_numel :].view(ctx.v_shape) + + if causal and ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + dim = ctx.qkv_format.index("s") + dq, dk, dv = [x.view(*x.shape[:dim], -1, *x.shape[dim + 2 :]) for x in [dq, dk, dv]] if ctx.qkv_format == "thd" and not ctx.use_fused_attention: dq[cu_seqlens_q_padded[-1] :].fill_(0) - if ctx.enable_mla: - dk[cu_seqlens_kv_padded[-1] :].fill_(0) - dv[cu_seqlens_kv_padded[-1] :].fill_(0) - else: - dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - assert torch.uint8 not in [dq.dtype, dkv.dtype] - if ctx.enable_mla: - dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] - else: - dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] - if not ctx.enable_mla: - dk, dv = dkv[0], dkv[1] + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + + if ctx.fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) if cp_size_a2a > 1: + if ctx.fp8 and ctx.is_input_fp8: + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2741,20 +2477,21 @@ def backward(ctx, dout): ctx.cp_stream, False, ) + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] if ctx.qkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if attn_dbias is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) - # converting torch.uint8 to float8tensor - if ctx.fp8 and ctx.is_input_fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + nvtx_range_pop(f"{nvtx_label}") return ( None, @@ -2783,6 +2520,8 @@ def backward(ctx, dout): None, None, None, + None, + None, ) @@ -2912,22 +2651,22 @@ def forward( else: cu_seqlens_q_padded = None - # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) @@ -2947,8 +2686,8 @@ def forward( for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( @@ -2970,7 +2709,7 @@ def forward( k.shape[1], max_seqlen_kv_, k.device ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( @@ -3106,17 +2845,17 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) @@ -3157,8 +2896,8 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], @@ -3166,7 +2905,7 @@ def backward(ctx, dout): ) max_seqlen_kv = seq_end_idx - seq_start_idx k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) @@ -3239,7 +2978,7 @@ def backward(ctx, dout): dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ x.movedim(seq_dim, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] @@ -3258,13 +2997,13 @@ def backward(ctx, dout): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, *dv.shape[-3:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) @@ -3335,6 +3074,7 @@ def forward( use_flash_attn_3, softmax_type, softmax_offset, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3342,7 +3082,6 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3406,32 +3145,37 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + q_fp8, k_fp8, v_fp8 = (None, None, None) if fp8: if use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] + else: + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: assert False, "FP8 is only supported with Fused Attention!" else: @@ -3448,24 +3192,18 @@ def forward( softmax_offset, 1, cp_size, cp_group, cp_stream, True ) - if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] - + out_fp8 = None + out_f16 = None batch_size = q.shape[batch_dim] + q_part, k_part, v_part = q, k, v + out_part = None if use_fused_attention: - q_part, k_part, v_part = q, k, v if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v, fake_dtype=qkv_dtype, internal=True - ) - out, aux_ctx_tensors = fused_attn_fwd( + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + out_, aux_ctx_tensors = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3474,7 +3212,7 @@ def forward( q_part, k_part, v_part, - qkv_dtype, + fwd_nominal_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -3489,8 +3227,24 @@ def forward( softmax_type=softmax_type, softmax_offset=softmax_offset, ) - if fp8: - out = out._data + if isinstance(out_, Float8Tensor): + out_fp8 = out_ + out_ = out_._data + if is_bwd_fp8 and not ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): + out_part = out_fp8 + else: + out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ + out_part = out_ + if ( + fp8 + and is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_part = O_quantizer(out_) else: fa_forward_args_thd = get_fa_args( True, @@ -3502,67 +3256,67 @@ def forward( max_seqlen_kv=max_seqlen_kv, ) fa_outputs = flash_attn_fwd( - q, - k, - v, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: - out, softmax_lse = fa_outputs[4], fa_outputs[5] + out_, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not use_flash_attn_3 else None else: - out, softmax_lse = fa_outputs[0], fa_outputs[1] + out_, softmax_lse = fa_outputs[0], fa_outputs[1] rng_state = fa_outputs[3] if not use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] + out_part = out_ - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) - out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) + out_ = flash_attn_a2a_communicate( + out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) + # [b*s, h, d] -> [b, s, h, d] + out_ = out_.view(batch_size, -1, *out_.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + # [s*b, h, d] -> [s, b, h, d] + out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if fp8: - if is_output_fp8: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=False - ) - out_ret = out_fp8 - out = out_fp8._data - else: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=True - ) - out_f16 = out_fp8.dequantize(dtype=qkv_dtype) - out_ret = out_f16 + if fp8 and use_fused_attention: + if fp8_recipe.float8_current_scaling(): + out_f16 = out_ + if is_output_fp8: + out_fp8 = O_quantizer(out_) + if fp8_recipe.delayed(): + out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) + if not is_output_fp8: + out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) else: - out_ret = out + out_f16 = out_ - if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - else: - if is_input_fp8: - q_save, k_save, v_save = q, k, v - else: - q_save, k_save, v_save = q_f16, k_f16, v_f16 - if is_output_fp8: - out_save = out + out_ret = out_fp8 if is_output_fp8 else out_f16 + + ctx.fp8 = fp8 and is_bwd_fp8 + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + if ctx.fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) else: - out_save = out_f16 + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) + f16_tensors = (q_part, k_part, v_part, out_part) + else: + f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - k_save, - v_save, - out_save, + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -3571,6 +3325,7 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.out_shape = out_ret.shape ctx.batch_size = batch_size ctx.cp_group = cp_group @@ -3585,14 +3340,14 @@ def forward( ctx.deterministic = deterministic ctx.window_size = window_size ctx.use_fused_attention = use_fused_attention - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 ctx.softmax_type = softmax_type - ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -3616,6 +3371,10 @@ def backward(ctx, dout): cp_size = get_distributed_world_size(ctx.cp_group) ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, q, k, v, @@ -3626,23 +3385,21 @@ def backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") - dout_dtype = dout.dtype + bwd_nominal_dtype = ctx.fwd_nominal_dtype + dqkv_te_dtype = None fused_attn_backend = None - fused_attn_dqkv_dtype = None + dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: + if not isinstance(dout, QuantizedTensorBase): dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dout_fp8 = dout + dqkv_te_dtype = dout._fp8_dtype dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3652,44 +3409,23 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if ctx.fp8_meta is not None: - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - dout = dout._data - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - k = ctx.QKV_quantizer.create_tensor_from_data( - k, fake_dtype=ctx.qkv_dtype, internal=True - ) - v = ctx.QKV_quantizer.create_tensor_from_data( - v, fake_dtype=ctx.qkv_dtype, internal=True - ) - q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] + if isinstance(dout, QuantizedTensorBase): + dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) + dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) + else: + dout = dout.view(*ctx.out_shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) - out, dout = flash_attn_a2a_communicate( - [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - out = ctx.O_quantizer.create_tensor_from_data( - out, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - out = out.dequantize(dtype=ctx.qkv_dtype) - dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -3730,30 +3466,14 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: - q_part = q - k_part = k - v_part = v - out_part = out - dout_part = dout - + q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - + q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_part = out + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -3764,8 +3484,8 @@ def backward(ctx, dout): v_part, out_part, dout_part, - dout_dtype, - fused_attn_dqkv_dtype, + bwd_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, @@ -3780,10 +3500,9 @@ def backward(ctx, dout): **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) - if ctx.fp8: - dq = dq._data - dk = dk._data - dv = dv._data + if isinstance(dq, Float8Tensor): + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] @@ -3813,7 +3532,7 @@ def backward(ctx, dout): **fa_backward_kwargs, ) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False ) @@ -3835,17 +3554,22 @@ def backward(ctx, dout): ) if ctx.fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data( - dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dk = ctx.dQKV_quantizer.create_tensor_from_data( - dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dv = ctx.dQKV_quantizer.create_tensor_from_data( - dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - if not ctx.is_input_fp8: - dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] + if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ctx.fp8_recipe.delayed(): + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] + if not ctx.is_input_fp8: + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + ) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -3876,6 +3600,7 @@ def backward(ctx, dout): None, None, d_softmax_offset, + None, ) @@ -3910,6 +3635,8 @@ def attn_forward_func_with_cp( use_flash_attn_3=False, softmax_type="vanilla", softmax_offset=None, + fp8_output=False, + layer_number=1, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3973,10 +3700,15 @@ def attn_forward_func_with_cp( """ if cp_comm_type == "a2a+p2p": - assert isinstance( - cp_group, list - ), "Hierarchical CP implementation needs multi-level CP groups!" - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + assert ( + isinstance(cp_group, list) and len(cp_group) == 2 + ), "CP implementation a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert ( + attn_bias_type == "no_bias" + ), f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!" if get_distributed_world_size(cp_group[0]) == 1: cp_group = cp_group[1] cp_comm_type = "p2p" @@ -4064,6 +3796,8 @@ def attn_forward_func_with_cp( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -4082,6 +3816,7 @@ def attn_forward_func_with_cp( use_flash_attn_3, softmax_type, softmax_offset, + fp8_output, ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index f72cd6926..a19d08ae5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -14,8 +14,22 @@ from torch.nn.parameter import Parameter import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Format, + Recipe, + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.fp8 import ( + get_fp8_te_dtype, + FP8GlobalStateManager, + RecipeState, + DelayedScalingRecipeState, + MXFP8BlockScalingRecipeState, + Float8CurrentScalingRecipeState, + Float8BlockScalingRecipeState, +) from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -73,6 +87,67 @@ "_alibi_bias_require_update": False, } +""" +This feature is **experimental** and subject to change. + +Some models may use different FP8 recipes for their linear layers and attention layers. To support this, +users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer, +or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention +layers as follows. + ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| Linear | Attention | Configuration | ++===================+===========+===================================================================================+ +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); | +| | | export NVTE_DPA_FP8_RECIPE="F16" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); | +| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); | +| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| +| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); | +| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | +| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); | +| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | +| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +""" +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} +_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")] +_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent") +_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")) +_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1" + + __all__ = ["DotProductAttention"] @@ -462,6 +537,231 @@ def set_context_parallel_group( self.cp_stream = cp_stream self.cp_comm_type = cp_comm_type + def init_fp8_metadata(self, num_gemms: int = 1) -> None: + """ + Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support. + Initialize fp8 related metadata and tensors during fprop. + """ + _original_recipe = self.fp8_meta.get("recipe", None) + + # global recipe set in fp8_autocast() + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + + # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to + # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. + # + # fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers + # -------------------------------------------------------------------------------------------- + # DelayedScaling (DS) | unset | DS | all DS + # Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP + # x={DS, CS} | y | refer to row x=y | refer to row x=y + fp8_recipe_dpa = fp8_recipe + fp8_recipes = fp8_recipe + if _dpa_fp8_recipe == "F16": + # ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False + fp8_recipe.fp8_dpa = False + fp8_recipe.fp8_mha = False + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + fp8_recipe, + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in ( + "", + "Float8CurrentScaling", + ): + # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format + # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP + fake_recipes = [ + Float8CurrentScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False + if not fp8_recipe_dpa.float8_per_tensor_scaling(): + assert not ( + fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha + ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + + # reduce over TP+CP groups; expect fp8_group to be set up so + # assume attention uses the same fp8_group as GEMMs + fp8_group = FP8GlobalStateManager.get_fp8_group() + + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self.fp8 = FP8GlobalStateManager.is_fp8_enabled() + self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration + self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.fp8_parameters or fp8_enabled: + self.fp8_meta["global_recipe"] = fp8_recipe + self.fp8_meta["local_recipes"] = ( + fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes] + ) + + if self.fp8_parameters or fp8_enabled: + if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]: + # FP8 init has already been run and recipe is the same, don't do anything. + return + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + else: + # If fp8 isn't enabled, turn off and return. + self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(fp8_recipes) + + if fp8_enabled: + # Set FP8 and other FP8 metadata + self.fp8_meta["num_gemms"] = num_gemms + self.fp8_meta["fp8_group"] = fp8_group + + # Set FP8_MAX per tensor according to recipe + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + + # Allocate scales and amaxes + self.init_fp8_meta_tensors(fp8_recipes) + self.fp8_initialized = True + + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + + def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: + """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe): + recipe = [recipe] + fp8_recipe_dpa = recipe[-1] + fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + # Return early if recipe state matches recipe + if self.fp8_meta_tensors_initialized: + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd) + return + if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return + if fp8_recipe_dpa.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return + if fp8_recipe_dpa.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return + + # When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers. + # See table above in init_fp8_metadata for more detail. + num_gemms = [2, 1] if len(recipe) == 2 else [3] + # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and + # 2 (grad_output and grad_input) for bwd + num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms] + + # Initialize recipe state and quantizers + recipe_states = [ + RecipeState.create( + recipe[i], + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors[i], + ) + for i in range(len(recipe)) + ] + + self.fp8_meta[fp8_meta_tensor_key] = ( + recipe_states[-1] if len(recipe) == 2 else recipe_states[0] + ) + self.quantizers[fp8_meta_tensor_key] = [] + for recipe_state in recipe_states: + self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + @no_torch_dynamo(recursive=False) def forward( self, @@ -485,6 +785,7 @@ def forward( fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, pad_between_seqs: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -657,6 +958,8 @@ def forward( pad_between_seqs: Optional[bool], default = `None` If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If true, there are padding tokens between individual sequences in a packed batch. + fp8_output: Optional[bool], default = `False` + Whether to enforce output to be in FP8 or not. """ with torch.cuda.device(query_layer.device), self.prepare_forward( @@ -693,6 +996,8 @@ def forward( tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + else: + fp8_output = False # checks for q/k/v shapes assert ( @@ -1092,6 +1397,7 @@ def forward( quantizers=self.quantizers, inference_params=inference_params, flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, ) if use_fused_attention: @@ -1140,6 +1446,7 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + fp8_output=fp8_output, ) return self.fused_attention( query_layer, @@ -1169,6 +1476,7 @@ def forward( pad_between_seqs=pad_between_seqs, inference_params=inference_params, softmax_offset=softmax_offset, + fp8_output=fp8_output, ) from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled @@ -1180,6 +1488,7 @@ def forward( ) if use_unfused_attention: + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -1198,6 +1507,10 @@ def forward( alibi_slopes=alibi_slopes, inference_params=inference_params, softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return self.unfused_attention( _alibi_cache, @@ -1215,5 +1528,9 @@ def forward( alibi_slopes=alibi_slopes, inference_params=inference_params, softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 72c595e3f..ea7b0e876 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -17,6 +17,7 @@ from packaging.version import Version as PkgVersion import torch +import torch.distributed as dist import torch.nn.functional as F import transformer_engine_torch as tex import transformer_engine as te @@ -32,11 +33,13 @@ META_DO, META_S, META_DP, - META_O_CP, - META_DQKV_CP, ) from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) from transformer_engine.pytorch.fp8 import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -44,6 +47,8 @@ from transformer_engine.pytorch.utils import ( get_device_compute_capability, get_cudnn_version, + SplitAlongDim, + combine_tensors, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -54,6 +59,9 @@ # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) +# print quantizer info for a particular layer on a particular rank +_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1")) +_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0")) _cu_seqlens_cache = {} @@ -350,8 +358,31 @@ def get_attention_backend( field.name: getattr(attention_params, field.name) for field in fields(attention_params) } run_config.update(attention_params_dict) + # Add FP8 environment variables to config if fp8: + # all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd) run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd + run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1")) + # switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling" + _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") + run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe + if _dpa_fp8_recipe != "": + # config new recipe if switched + run_config["NVTE_DPA_FP8_FORMAT"] = os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID") + run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv( + "NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent" + ) + run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int( + os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1") + ) + run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int( + os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") + ) + # UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow + run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") + ) logger.debug("Running with config=%s", run_config) # The following sections check if `FlashAttention` supports the provided attention params, @@ -431,8 +462,20 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") - use_unfused_attention = False + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + if not allow_emulation: + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") + use_unfused_attention = False + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if ( + use_fused_attention + and fp8_recipe.float8_current_scaling() + and device_compute_capability < (10, 0) + ): + logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") + use_fused_attention = False # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size @@ -1875,11 +1918,10 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): +def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: - num_of_nones = 8 if cp_specific_quantizers else 6 - return [None] * num_of_nones + return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1888,6 +1930,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True dQKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1897,22 +1940,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True - dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] - dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_CP_quantizer.internal = True - O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] - O_CP_quantizer.set_usage(rowwise=True, columnwise=False) - - if cp_specific_quantizers: - return ( + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + +def print_quantizers( + label, + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, +): + """Print the type and scale/amax of attention quantizers""" + _to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2 + if ( + _to_print + and _print_layer == layer_number + and ( + not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank) + ) + ): + names = [ + "QKV_quantizer", + "S_quantizer", + "O_quantizer", + "dO_quantizer", + "dP_quantizer", + "dQKV_quantizer", + ] + quantizers = [ QKV_quantizer, - O_quantizer, - O_CP_quantizer, S_quantizer, - dQKV_quantizer, - dQKV_CP_quantizer, + O_quantizer, dO_quantizer, dP_quantizer, - ) + dQKV_quantizer, + ] + if "forward" in label: + names = names[:3] + quantizers = quantizers[:3] + if "backward" in label: + names = names[3:] + quantizers = quantizers[3:] + for i, q in enumerate(quantizers): + type_str = "" + if q is None: + type_str = "None" + elif isinstance(q, Float8Quantizer): + type_str = "DS" + elif isinstance(q, Float8CurrentScalingQuantizer): + type_str = "CS" + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) - return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): + """Combine q,k,v based on qkv_layout and quantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + src_nominal_dtype = q.dtype + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = combine_tensors([q, k, v], dim) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv = combine_tensors([k, v], dim) + tensors = [q, kv] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, kv_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True) + case 3: + tensors = [q, k, v] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype) + for x in [q_data, k_data, v_data] + ] + + return q_fp8, k_fp8, v_fp8 + + +def combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None +): + """Combine q,k,v based on qkv_layout and dequantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + src_nominal_dtype = q_fp8.dtype + else: + assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" + if des_nominal_dtype is None: + des_nominal_dtype = src_nominal_dtype + + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv_data = combine_tensors([q_data, k_data, v_data], dim) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv_data = combine_tensors([k_data, v_data], dim) + tensors = [q_data, kv_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + k, v = SplitAlongDim.apply(kv, dim, [1, 1], True) + case 3: + tensors = [q_data, k_data, v_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + return q, k, v diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 790d78c75..b2f1ff1ac 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Multi-head Attention.""" +import os import collections from typing import Callable, List, Optional, Tuple, Union import torch @@ -31,7 +32,13 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + +# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast(). +# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling" +# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa. +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1" +_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1" class MultiheadAttention(torch.nn.Module): @@ -570,10 +577,12 @@ def set_context_parallel_group( self.cp_size = get_distributed_world_size(cp_group) self.cp_rank = get_distributed_rank(cp_group) elif isinstance(cp_group, list): - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" assert ( cp_comm_type == "a2a+p2p" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" + assert ( + len(cp_group) == 2 + ), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_size_p2p = get_distributed_world_size(cp_group[1]) @@ -730,10 +739,22 @@ def forward( # Query, Key, and Value # ====================== - fp8_mha = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.get_fp8_recipe().fp8_mha - ) + fp8 = FP8GlobalStateManager.is_fp8_enabled() + if _dpa_fp8_recipe == "": + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_dpa = fp8_recipe.fp8_dpa + fp8_mha = fp8_recipe.fp8_mha + float8_current_scaling = fp8_recipe.float8_current_scaling() + else: + fp8_dpa = _dpa_fp8_recipe_dpa + fp8_mha = _dpa_fp8_recipe_mha + float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe + qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling + # DPA: always produce FP8 output when fp8=True to take advantage of the O amax + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + # Proj Gemm: match DPA output except for Float8CurrentScaling + proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None if self.attention_type == "self": @@ -742,7 +763,7 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -752,7 +773,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) num_queries_per_key_value = ( @@ -806,7 +827,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.qkv_weight_interleaved: @@ -861,7 +882,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -871,7 +892,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) # [sq, b, hp] --> [sq, b, np, hn] @@ -972,6 +993,7 @@ def forward( fast_zero_fill=fast_zero_fill, inference_params=inference_params, pad_between_seqs=pad_between_seqs, + fp8_output=dpa_fp8_output, ) # =================== @@ -980,7 +1002,7 @@ def forward( projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, - fp8_grad=isinstance(context_layer, QuantizedTensor), + fp8_grad=proj_fp8_grad, ) if self.return_bias: diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index df2f5d1ca..94a12c4a0 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -109,9 +109,6 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 def fused_attn_fwd( diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index c94bd0d2a..978bee52d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -201,7 +201,7 @@ class Float8CurrentScalingQuantizer : public Quantizer { * amax to be initialized to zero. */ std::pair create_unquantized_tensor_with_amax( - const std::vector& shape, DType dtype); + const std::vector& shape, DType dtype, std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 4edc6d81e..cc33f2a89 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -78,6 +78,11 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data); + std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 5db9dd73d..344bc4ab0 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -53,6 +53,47 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( return fused_attention_backend; } +// helper function for S and dP quantizers +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data) { + std::unique_ptr T_quantizer = convert_quantizer(quantizer); + TensorWrapper te_T; + py::object py_T; + if (quantizer.is_none()) { + // high precision + auto *none_quantizer = dynamic_cast(T_quantizer.get()); + if (data.has_value()) { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype); + } + } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + // delayed scaling; this helps initialize scale_inv + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, std::nullopt); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // current scaling + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + } + } + return {std::move(te_T), std::move(py_T)}; +} + // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, @@ -66,44 +107,30 @@ std::vector fused_attn_fwd( py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, size_t rng_elts_per_thread) { - TensorWrapper te_Q, te_K, te_V, te_O, te_S; - auto none = py::none(); - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); + // create QKV tensor wrappers + TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); - - // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + // create S tensor + TensorWrapper te_S; + py::object py_S; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + + // create O tensor + TensorWrapper te_O; + py::object py_O; + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); - std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - // create output tensor O - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; - py::object o_python, s_python; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // Initialize FP8 tensor with scale-inverse - auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - } - auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); // construct NVTE tensors TensorWrapper te_Bias; @@ -114,11 +141,12 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - te_O.zero_(at::cuda::getCurrentCUDAStream()); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if ((h * d) % block_size == 0) { + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + te_O.zero_(at::cuda::getCurrentCUDAStream()); + } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { @@ -181,7 +209,8 @@ std::vector fused_attn_fwd( auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto rng_state = torch::empty({2}, options.dtype(torch::kInt64)); + auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options); philox_unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); @@ -210,7 +239,7 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; - output_tensors.push_back(o_python); + output_tensors.push_back(py_O); auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), @@ -280,50 +309,44 @@ std::vector fused_attn_bwd( const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, py::handle dp_quantizer, py::handle dqkv_quantizer) { auto none = py::none(); - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + + // create QKV, O, dO tensor wrappers + TensorWrapper te_Q, te_K, te_V, te_O, te_dO; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); te_O = makeTransformerEngineTensor(O, none); te_dO = makeTransformerEngineTensor(dO, none); - // qkv type from the te_Q - std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); - const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - py::object s_python, dp_python; - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); - } + // create S and dP tensors + TensorWrapper te_S, te_dP; + py::object py_S, py_dP; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + std::tie(te_dP, py_dP) = + quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt); + // create dQ, dK, dV tensors + TensorWrapper te_dQ, te_dK, te_dV; + py::object py_dQ, py_dK, py_dV; + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; - auto d_v = v_shape[v_shape.size() - 1]; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = d_v; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); at::Tensor dQ, dK, dV, dQKV, dKV; - py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { + options = options.dtype(torch::kUInt8); + } + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + options = options.dtype(fake_dtype); + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: @@ -396,39 +419,27 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_dQ, py_dQ) = - fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); - std::tie(te_dK, py_dK) = - fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); - std::tie(te_dV, py_dV) = - fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); - } else { - auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); - std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); - } + + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); // construct NVTE tensors - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } } - - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); @@ -605,7 +616,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 int seq_dim = tensor.dim() == 3 ? 0 : 1; - int batch = cu_seqlens.size(0) - 1; int num_heads = tensor.size(seq_dim + 1); int dim_per_head = tensor.size(seq_dim + 2); int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); @@ -769,8 +779,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t NVTE_CHECK(world_size > 0); NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); - int batch = cu_seqlens.size(0) - 1; - std::vector shape = {total_tokens / world_size}; at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); @@ -808,7 +816,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, **************************************************************************************************/ at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { - int max_seq_len = tensor.size(1); int h = tensor.size(2); int d = tensor.size(3); std::vector shape = {t, h, d}; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index e9647b44f..2c1edae4c 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -37,7 +37,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob // Convert input tensor to C++ object auto input_contiguous = tensor.contiguous(); - const auto input_cpp = makeTransformerEngineTensor(input_contiguous); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + // Set amax if use_existing_amax = true (only valid for CS) + bool use_existing_amax = false; + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + use_existing_amax = quantizer.attr("use_existing_amax").cast(); + if (use_existing_amax) { + const at::Tensor &amax = quantizer.attr("amax").cast(); + input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + } // Initialize output tensor TensorWrapper output_cpp; @@ -57,7 +68,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob } // Perform quantization - quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + if (use_existing_amax) { + auto *quantizer_cs = dynamic_cast(quantizer_cpp.get()); + quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp); + } else { + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + } return output_py; } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 2abe9614e..8470466ae 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -390,9 +390,13 @@ std::pair Float8CurrentScalingQuantizer::create_tenso std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, - DType dtype) { + DType dtype, + std::optional data) { amax.zero_(); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); return {std::move(out_cpp), std::move(out_py)}; diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index a75a03bfa..15017913f 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -970,7 +970,9 @@ def make_quantizers(self) -> list: from .tensor.float8_tensor import Float8CurrentScalingQuantizer return [ - Float8CurrentScalingQuantizer(self.dtype, device=self.device) + Float8CurrentScalingQuantizer( + self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales + ) for i in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 1524584aa..18750d039 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -215,6 +215,8 @@ class Float8CurrentScalingQuantizer(Quantizer): amax: torch.Tensor """FP8 datatype""" dtype: TE_DType + """amax update options""" + use_existing_amax: bool """amax reduction options""" with_amax_reduction: bool amax_reduction_group: Optional[dist_group_type] @@ -229,6 +231,7 @@ def __init__( *, rowwise: bool = True, columnwise: bool = True, + use_existing_amax: bool = False, with_amax_reduction: bool = False, amax_reduction_group: Optional[dist_group_type] = None, force_pow_2_scales: bool = False, @@ -238,6 +241,7 @@ def __init__( self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype + self.use_existing_amax = use_existing_amax self.with_amax_reduction = with_amax_reduction self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales From 7fa0f5541bff9df574c7b7c7c6b6cd46e9009b57 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 30 Sep 2025 16:51:43 -0700 Subject: [PATCH 016/141] [Pytorch] Support for Swiglu Activation used in GPT OSS (#2161) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Test working as I think it should work Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> Signed-off-by: Varun Thumbe minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe fix linting error Signed-off-by: Varun Thumbe * initial draft of changes to get GPT oss based swiglu integrated, gated kernels needs to be fixed Signed-off-by: Varun Thumbe * redundant implementation for the pytorch to te hook up, refactoring to be done later Signed-off-by: Varun Thumbe * all gated kernels modified, pytest working for oss swiglu Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix the merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] fix cross entropy vanishing gradients (#2139) * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Signed-off-by: Varun Thumbe Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: Varun Thumbe [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe Update list of authorized CI users (#2152) Signed-off-by: Tim Moon Signed-off-by: Varun Thumbe a bit of cleanup Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * accidentally had removed some activations, minor bug in the templated function Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262513 +0000 parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262476 +0000 parent de9ef2fe450daae0d4ea1b647a37219f72814f66 author Varun Thumbe 1757373536 +0000 committer Varun Thumbe 1758262304 +0000 merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] fix cross entropy vanishing gradients (#2139) * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen Update list of authorized CI users (#2152) Signed-off-by: Tim Moon Fused RoPE with combined QKV input. (#2122) * Fused RoPE with combined QKV input. Initial commit for Dropout with 8-bit RNG Fix documentation Initial commit for Fused QKV RoPE WIP Initial tests passing Enable rotary percent and margin Enable CP2, start_positions, interleaved Cleanup test Revert "Fix documentation" This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715. Revert "Initial commit for Dropout with 8-bit RNG" This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca. Cleanup. Minor cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernels Signed-off-by: Vasudevan Rengasamy * Misc. Cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernel performance Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Move fused_qkv_rope test to test_fused_rope.py Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * apply shared memory optimization to separate fused rope kernels Signed-off-by: Xin Yao * fix lint Signed-off-by: Xin Yao --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * accidentally removed the copyright Signed-off-by: Varun Thumbe * fix linting issue Signed-off-by: Varun Thumbe * minor issue in comments Signed-off-by: Varun Thumbe * Commit is for another PR Signed-off-by: vthumbe1503 * revert changes since this belongs to another PR Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert change back since belongs to another PR Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changes belong to another PR Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert changes here Signed-off-by: vthumbe1503 Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162) * add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion Signed-off-by: tongliu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tongliu Co-authored-by: tongliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [JAX] Scale swizzling via JAX transpose op (#2163) * add swizzle in jax Signed-off-by: Phuong Nguyen * added outer_impl Signed-off-by: Phuong Nguyen * clean up FFI Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Extract cpp distributed tests into a separate project (#2165) * Extract cpp distributed tests into a separate project Signed-off-by: Vladimir Cherepanov * Remove obsolete exclusion Signed-off-by: Vladimir Cherepanov * Run L1_cpp_distributed tests if at least 4 GPUs Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (#2129) * test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell * assert line change Signed-off-by: Jonathan Mitchell * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jonathan Mitchell Co-authored-by: Jonathan Mitchell Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sudhakar Singh * address review comments Signed-off-by: Varun Thumbe * cleanup Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix linting error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [PyTorch Debug] Fix issue with negative underflow% stat. (#2107) * fix underflows log issue Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * Address review comments, fix mxfp8 kernel bug: was not passing clamped swiglu parameter correctly Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Lower precision gated-act to accelerate FP8 current-scaling. (#2153) * Applying the original precision as Norm outputs' and activation compuations. Signed-off-by: Ming Huang * Adding knob to control norm output precision. Signed-off-by: Ming Huang * Removing the knob and applying lower-precision norm with current-scaling only. Signed-off-by: Ming Huang * Fix the error when quantizer==None Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang [PyTorch] Support activation CPU offloading in fusible ops (#2158) * Add CPU offloading logic to ops. Fix test to compute dgrad. Signed-off-by: Tim Moon * Make sure grads are contiguous in op backwards Signed-off-by: Tim Moon * Add op-based MLP to CPU offloading tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Handle different weight cache behavior on Hopper/Blackwell Add MXFP8 to CPU offload tests. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove MXFP8 test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174) * Do not use norm fwd + amax fusion if cudnn backend is requested Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Read envirornment vairable directly to avoid include error Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Fix unjoined comm stream in UB communicator (#2160) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> FP8 Output Quantization for GEMM (#2123) * Test working as I think it should work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani feat: Add support for multiple quantization modes in the UB communicators (#2043) [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> fix linting error Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> address review comments Signed-off-by: Varun Thumbe * Update test_multi_process_distributed_grouped_gemm.py change accidentally added while merging Signed-off-by: vthumbe1503 * Update dense.py change accidentally added while merging Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address revie comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bug solved: delayed scaling quantization with mxfp8 inputs didnt work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the unit test error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * just to trigger ci Signed-off-by: Varun Thumbe * address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> TE Gemma tutorial attempt#2 (#1839) * add tutorial files and other local changes Signed-off-by: Sudhakar Singh * remove extraneous code for easy debu Signed-off-by: Sudhakar Singh * make cuda graphs work with non-paged and paged attention Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * perf imp for kv cache ops Signed-off-by: Sudhakar Singh * add code for calibration Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize kv_cache reindex and copy kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changes to make quantizers work with fp8_calibration Signed-off-by: Sudhakar Singh * avoid reindexing from python side Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename variable from previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use quantizer only if needed Signed-off-by: Sudhakar Singh * functionality of the tutorial tested and perf checked Signed-off-by: Sudhakar Singh * remove files and update headers/licenses Signed-off-by: Sudhakar Singh * update header/license Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tutorial for review Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make weights downloadable on the fly; remove extra print statements Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint and update comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comma back, typo Signed-off-by: Sudhakar Singh * sequence_start_positions should be None for training Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add paged attention numberes and update requirements.txt file Signed-off-by: Sudhakar Singh * more fixes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make tutorial work on blackwell Signed-off-by: Sudhakar Singh * remove gemma FT tutorial for now Signed-off-by: Sudhakar Singh * fixing the headings placement and rewording attention -> kv caching Signed-off-by: Sudhakar Singh * fixes from comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the images Signed-off-by: Sudhakar Singh * misc fixes Signed-off-by: Sudhakar Singh * add more comments to te_gemma.py and cleanup utils.py Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more information about the hierarchy of the classes used in the tutorial Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add better cuda graphs picture Signed-off-by: Sudhakar Singh * addd updated cuda graphs pictures Signed-off-by: Sudhakar Singh * add illustrated cuda graphs Signed-off-by: Sudhakar Singh * fix Signed-off-by: Sudhakar Singh * small fixes in documentation Signed-off-by: Sudhakar Singh * add torch.no_grad() to force reduced memory usage Signed-off-by: Sudhakar Singh * some fixes from recent comments Signed-off-by: Sudhakar Singh * more fixes from remaining comments Signed-off-by: Sudhakar Singh * add te_rope_emb to class desc Signed-off-by: Sudhakar Singh * fix tutorial wording; add calibration fix to grouped_linear.py Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Fix memory overhead of linear layer when all gather from sequence parallel (#2125) * fix memory overhead of all gather from sequence parallel Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * quick fix the errors when for UB buffers Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Avoid deallocating FP8 scale-invs since they are reused Signed-off-by: Tim Moon --------- Signed-off-by: Yuzhong Wang Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Fix incorrect TP rank calculation when using data parallel (#2179) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> [Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045) * feat: add cutlass group gemm support Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor multi tensor gemm interface Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor nvte_multi_stream_cublas_gemm func and add license info Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add unit test for cutlass group gemm Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add cutlass support type protect Signed-off-by: Min Yang * add tests and fix lint Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: fix unit tests error Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: refactor host workspace malloc Signed-off-by: Min Yang * update cutlass Signed-off-by: Xin Yao * update cutlass Signed-off-by: Xin Yao * further relex threshold and add a env var to warn fall back Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Min Yang Signed-off-by: Xin Yao Signed-off-by: alan yang <89962857+cassiewilliam@users.noreply.github.com> Co-authored-by: Min Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen [PyTorch] Support FA3 for MLA and with CP (#1907) feature(FA3,MLA,CP): 1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward 2. Update get_attention_backend method because FA3 support MLA now 3. Add CP MLA support for FA3 4. Add unit tests for FA3 MLA CP 5. Update attention doc Signed-off-by: zhujian Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185) * Fix cudnn version checks for kv cache for sm89. Add cudnn version check in preparation for 9.14 when getting backend Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix for cuDNN version condition check Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Use limit=0.75 in clamped SwiGLU test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * accidentally removed a line while resolving merge conflict Signed-off-by: Varun Thumbe * match pytorch implementation: dclamp should be 1 for borders of clamping limits as well Signed-off-by: Varun Thumbe * fix dswiglu quantization fusion bug Signed-off-by: Varun Thumbe * pass param by reference as much as possible Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * float should rather be bool: fix by copilot Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * { missed in activation.cpp Signed-off-by: Varun Thumbe * minor formatting change Signed-off-by: Varun Thumbe * nvfp4 change Signed-off-by: Varun Thumbe --------- Signed-off-by: Varun Thumbe Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Signed-off-by: vthumbe1503 Signed-off-by: Jonathan Mitchell Signed-off-by: Pawel Gadzinski Signed-off-by: Kshitij Lakhani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Jonathan Mitchell Co-authored-by: Sudhakar Singh --- tests/pytorch/test_fusible_ops.py | 74 +++++++++ .../common/activation/activation_template.h | 10 +- transformer_engine/common/activation/gelu.cu | 12 +- transformer_engine/common/activation/relu.cu | 12 +- .../common/activation/swiglu.cu | 23 ++- .../include/transformer_engine/activation.h | 40 +++++ .../common/util/cast_gated_kernels.cuh | 146 +++++++++++------- transformer_engine/common/util/math.h | 39 ++++- .../common/util/vectorized_pointwise.h | 26 +++- transformer_engine/pytorch/csrc/extensions.h | 4 + .../pytorch/csrc/extensions/activation.cpp | 139 ++++++++++++----- .../pytorch/csrc/extensions/pybind.cpp | 6 + .../pytorch/ops/basic/__init__.py | 14 +- .../pytorch/ops/basic/activation.py | 36 +++++ 14 files changed, 459 insertions(+), 122 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 440986661..231fa64bc 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1736,6 +1736,80 @@ def test_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) + def test_clamped_swiglu( + self, + *, + out_shape: Iterable[int] = (32, 32), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, + limit: float = 0.75, + alpha: float = 1.702, + ): + # Test SwiGLU variant used in GPT OSS. + # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x_glu, x_linear = x_ref.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + y_ref = out_glu * (x_linear + 1) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + recipe = make_recipe(quantization) + + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantize_backward), + te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), + te_ops.Quantize(forward=quantize_forward, backward=False), + ) + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = forward(x_test) + + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if quantized_compute and quantization == "nvfp4": + tols = dtype_tols(tex.DType.kFloat4E2M1) + elif quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) @pytest.mark.parametrize("dtype", _dtypes) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 67f173a4a..1d9a3fb43 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, } template -void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { +void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = false; constexpr NVTETensor grad = nullptr; - - quantize_gated_helper(grad, input, output, stream); + quantize_gated_helper(grad, input, output, p, stream); } template -void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; constexpr bool IS_DGATED = true; - - quantize_gated_helper(grad, input, output, stream); + quantize_gated_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 0cf43007a..4949ba590 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dgelu>(grad, input, output, e, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dqgelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index a794b7315..c74fc6eee 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, drelu>(grad, input, output, e, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dsrelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 819496474..cafc48abb 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dsilu>(grad, input, output, e, stream); +} + +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream) { + NVTE_API_CALL(nvte_clamped_swiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha}; + gated_act_fn>(input, output, param, stream); +} + +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream) { + NVTE_API_CALL(nvte_clamped_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha}; + dgated_act_fn, clamped_dsilu>( + grad, input, output, param, stream); } diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 49029ed58..e50d71040 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -173,6 +173,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input used in GPT OSS. + * + * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This Gated activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -230,6 +250,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS. + * + * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream); + /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 6093b54b6..ca37a2831 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -55,7 +55,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_act, const __grid_constant__ CUtensorMap tensor_map_output_gate, float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; @@ -161,7 +162,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; OType *out_act_sh_curr = out_act_sh + buff * buff_elems; OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; - #pragma unroll for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; @@ -171,6 +171,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); @@ -178,18 +184,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + if (act_elt <= p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; + } else { + dact_x = 0.0f; + } } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; + float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); @@ -197,7 +212,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) amax = fmaxf(amax, fabsf(after_dact)); amax = fmaxf(amax, fabsf(after_dgate)); } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; + const float after_act = ActOP(act_elt, p) * gate_elt; out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); amax = fmaxf(amax, fabsf(after_act)); } @@ -300,7 +315,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -476,25 +491,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); float after_act_elt; float after_gate_elt; - + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { @@ -720,27 +747,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate.data.elt[e]); float after_act_elt; float after_gate_elt; - + bool dgate_elt = true; + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f; after_act_rowwise[j] = after_act_elt; after_gate_rowwise[j] = after_gate_elt; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; after_act_rowwise[j] = after_act_elt; } @@ -885,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -960,15 +999,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) } template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1099,7 +1137,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1116,7 +1155,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: @@ -1125,7 +1165,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out OType, true, true, THREADS_PER_CHUNK_NON_COLWISE>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - mxfp8_kernel::cast_mxfp8_gated_kernel <<>>( @@ -1133,7 +1172,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; }); // NOLINT(*) @@ -1141,12 +1180,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } template -void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); NVTE_CHECK(input.flat_last_dim() % 2 == 0, "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); @@ -1168,7 +1204,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), {}, stream); + output->flat_last_dim(), p, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1177,7 +1213,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, + cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); CheckOutputTensor(*output, "dgated_act_output"); @@ -1206,7 +1243,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt reinterpret_cast(output->scale.dptr), reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), {}, stream); + grad.flat_last_dim(), p, stream); } else { NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); }); // NOLINT(*) @@ -1215,7 +1252,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, cudaStream_t stream) { constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); @@ -1255,17 +1292,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu if (is_delayed_tensor_scaling(output->scaling_mode)) { if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, stream); + cast_fp8_gated(grad, gated_input, output, p, stream); } else { if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); + cast_dgated(grad, gated_input, output, p, stream); } else { - cast_gated(gated_input, output, stream); + cast_gated(gated_input, output, p, stream); } } } else if (is_mxfp8_scaling(output->scaling_mode)) { if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); + cast_mxfp8_gated(grad, gated_input, output, p, stream); } else { NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", "by 32, got input of shape ", gated_input.data.shape); @@ -1281,7 +1318,7 @@ namespace detail { template void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - cudaStream_t stream) { + ParamOP &p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; @@ -1290,13 +1327,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, if (is_supported_by_CC_100()) { quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); + output_tensor, p, stream); } else { if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); + cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, + stream); } else { - cast_gated(gated_input_tensor, output_tensor, stream); + cast_gated(gated_input_tensor, output_tensor, p, stream); } } else { // MX scaling diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2d425d675..2f20817fb 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -11,6 +11,11 @@ namespace transformer_engine { struct Empty {}; +struct ClampedSwiGLUParam { + float limit; + float alpha = 1.702f; // Default value for QuickGELU +}; + template __device__ inline OType gelu(const IType val, const Empty&) { const float cval = val; @@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) { return s * (1.f - s); } +template +__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) { + const float cval = val; + Empty e = {}; + return cval * sigmoid(alpha * cval, e); +} + template __device__ inline OType qgelu(const IType val, const Empty& e) { + return qgelu_with_alpha(val, 1.702f); +} + +template +__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) { const float cval = val; - return cval * sigmoid(1.702f * cval, e); + Empty e = {}; + return alpha * cval * dsigmoid(alpha * cval, e) + + sigmoid(alpha * cval, e); } template __device__ inline OType dqgelu(const IType val, const Empty& e) { - const float cval = val; - return 1.702f * cval * dsigmoid(1.702f * cval, e) + - sigmoid(1.702f * cval, e); + return dqgelu_with_alpha(val, 1.702f); } template @@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) { return cval * sigmoid(cval, e); } +template +__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) { + const float cval = min(p.limit, static_cast(val)); // Clamping + return qgelu_with_alpha(cval, p.alpha); +} + template __device__ inline OType dsilu(const IType val, const Empty& e) { const float cval = val; return cval * dsigmoid(cval, e) + sigmoid(cval, e); } +template +__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) { + const bool dclamp_val = static_cast(val) <= p.limit; + const float clamp_val = min(static_cast(val), p.limit); + const float dsilu_val = dqgelu_with_alpha(clamp_val, p.alpha); + return dclamp_val ? dsilu_val : 0.0f; +} + template __device__ inline OType relu(IType value, const Empty&) { return fmaxf(value, 0.f); diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 0d667a0ec..dd6869e02 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -11,7 +11,7 @@ #include "../common.h" #include "../utils.cuh" - +#include "math.h" namespace transformer_engine { /* \brief Helper class that enables storing multiple values of type DType @@ -338,7 +338,7 @@ template void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, - const Param params, cudaStream_t stream) { + const Param ¶ms, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -372,7 +372,7 @@ template void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, OutputType *output, const fp32 *scale, fp32 *amax, - fp32 *scale_inv, const size_t N, const Param params, + fp32 *scale_inv, const size_t N, const Param ¶ms, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); @@ -431,7 +431,13 @@ __launch_bounds__(unary_kernel_threads) __global__ #pragma unroll for (int i = 0; i < nvec; ++i) { const ComputeType val = static_cast(loader0.separate()[i]); - const ComputeType val2 = static_cast(loader1.separate()[i]); + ComputeType val2 = static_cast(loader1.separate()[i]); + + if constexpr (std::is_same::value) { + // Clamp the gated value and add 1 at the end + ComputeType limit = p.limit; + val2 = std::min(std::max(-limit, val2), limit) + 1; + } ComputeType temp = static_cast(Activation(val, p) * val2); if (requires_amax) { __builtin_assume(max >= 0); @@ -532,10 +538,18 @@ __launch_bounds__(unary_kernel_threads) __global__ for (int i = 0; i < nvec; ++i) { const ComputeType grad_val = static_cast(grad_loader.separate()[i]); const ComputeType gelu_in = static_cast(input_loader0.separate()[i]); - const ComputeType gate_in = static_cast(input_loader1.separate()[i]); + ComputeType gate_in = static_cast(input_loader1.separate()[i]); + bool dgate_in = true; + + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + const ComputeType limit = p.limit; + dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp + gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; + } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; - ComputeType after_dgate = grad_val * Activation(gelu_in, p); + ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f; if (requires_amax) { __builtin_assume(max >= 0); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index cc33f2a89..d86a96959 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -205,6 +205,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); + +py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + float limit, float alpha); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index cdfb4be40..14cc084c0 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -3,7 +3,6 @@ * * See LICENSE for license information. ************************************************************************/ - #include "../extensions.h" #include "common.h" #include "pybind.h" @@ -12,10 +11,12 @@ namespace transformer_engine { namespace pytorch { namespace { +using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t); +using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t), - const at::Tensor& input, py::handle quantizer, - int shape_divisor = 1) { +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, + Args&&... args) { init_extension(); // Input tensor @@ -56,14 +57,28 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud // Compute activation in high precision, then quantize { auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); quantizer_cpp->quantize(temp_nvte, out_nvte); } break; case Impl::FULLY_FUSED: // Compute activation directly { - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), out_nvte.data(), stream); + } + }); } break; case Impl::FUSED_ACTIVATION_AMAX_FP8: @@ -73,7 +88,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); auto [temp_nvte, _] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; @@ -84,7 +106,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud static_cast(quantizer_cpp.get()); // Already checked cast is valid auto [temp_nvte, _] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; @@ -95,10 +124,9 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud return out_py; } -py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, - cudaStream_t), - const at::Tensor& grad_output, const at::Tensor& input, - py::handle quantizer) { +template +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, + py::handle quantizer, Args&&... args) { init_extension(); // Grad output and input tensors @@ -142,8 +170,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen { auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), - at::cuda::getCurrentCUDAStream()); + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } }); quantizer_cpp->quantize(temp_nvte, grad_input_nvte); } @@ -152,7 +184,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen // Compute activation backward directly { NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + } }); } break; @@ -163,8 +200,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); auto [temp_nvte, _] = fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; @@ -175,8 +218,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen static_cast(quantizer_cpp.get()); // Already checked cast is valid auto [temp_nvte, _] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; @@ -186,90 +235,98 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen return grad_input_py; } - } // namespace /* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_gelu, input, quantizer); + return activation_helper(input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dgelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_geglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dgeglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_qgelu, input, quantizer); + return activation_helper(input, quantizer); } py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dqgelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_qgeglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dqgeglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } /* ReLU and variants */ py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_relu, input, quantizer); + return activation_helper(input, quantizer); } py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_drelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_reglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dreglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_srelu, input, quantizer); + return activation_helper(input, quantizer); } py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dsrelu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_sreglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dsreglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } - /* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_silu, input, quantizer); + return activation_helper(input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dsilu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_forward(nvte_swiglu, input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return activation_backward(nvte_dswiglu, grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); +} + +/* clamped functions */ +py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { + return activation_helper(input, quantizer, 2, limit, alpha); +} + +py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + float limit, float alpha) { + return dactivation_helper(grad, input, quantizer, limit, alpha); } } // namespace pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 98f71f9a7..382adbfb0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -155,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); + m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, + "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), + py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -178,6 +181,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, + "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 2c903675f..28d49bf7b 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,7 +4,19 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU +from .activation import ( + GELU, + GEGLU, + QGELU, + QGEGLU, + ReLU, + ReGLU, + SReLU, + SReGLU, + SiLU, + SwiGLU, + ClampedSwiGLU, +) from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 22779b601..8a754c638 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -28,6 +28,7 @@ "SReGLU", "SiLU", "SwiGLU", + "ClampedSwiGLU", ] @@ -392,3 +393,38 @@ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dswiglu(*args, **kwargs) + + +class ClampedSwiGLU(_ActivationOperation): + r"""GPT-OSS + Implementation based on `GPT-OSS`__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + + Parameters + ---------- + limit: float + The clamp limit. + alpha: float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input: bool, default = False + Quantize input tensor when caching for use in the backward pass. + """ + + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): + super().__init__(cache_quantized_input=cache_quantized_input) + self.limit = limit + self.alpha = alpha + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) From ce18bee70fe98f041e91666f1935ffdd524090db Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 30 Sep 2025 16:53:01 -0700 Subject: [PATCH 017/141] [JAX] Load modules during initialize for Norm and Act primitives (#2219) Load modules during initialize Signed-off-by: Jeremy Berchtold Co-authored-by: JAX Toolbox --- transformer_engine/jax/csrc/extensions.h | 4 ++ .../jax/csrc/extensions/activation.cpp | 58 +++++++++++++++++ transformer_engine/jax/csrc/extensions/ffi.h | 15 +++++ .../jax/csrc/extensions/normalization.cpp | 63 +++++++++++++++++++ .../jax/csrc/extensions/pybind.cpp | 10 ++- 5 files changed, 148 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 92937dd46..2ab95002f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -41,16 +41,20 @@ inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == D // Activation XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, JAXX_Scaling_Mode scaling_mode, bool is_2x); // Normalization +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906b..b2b3db52c 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -148,6 +148,30 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Attr("is_2x"), FFI_CudaGraph_Traits); +Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, int64_t act_enum, + JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, + colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, + act_enum, scaling_mode, is_2x_int); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("is_2x")); + pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, JAXX_Scaling_Mode scaling_mode, bool is_2x) { @@ -410,5 +434,39 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("is_2x") .Attr("is_dbias"), FFI_CudaGraph_Traits); + +Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, + bool is_2x, bool is_dbias) { + return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, + act_input_buf, scale_buf, output_buf, colwise_output_buf, + scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, + DActLuDBiasQuantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act input + .Arg() // scale + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // amax + .Ret() // dbias + .Ret() // wkspace + .Attr("scaling_mode") + .Attr("act_enum") + .Attr("is_2x") + .Attr("is_dbias")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 852a67c6c..82f062a15 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream; using Dictionary = xla::ffi::Dictionary; constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare; +constexpr auto FFI_Initialize = xla::ffi::ExecutionStage::kInitialize; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); @@ -106,5 +107,19 @@ inline static size_t te_dtype_bytes(const DType& type) { } } +template +Error_Type wrapInStreamCapture(std::function func, + cudaStream_t stream, Args... args) { + cudaGraph_t graph{}; + NVTE_CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed)); + + Error_Type error = func(stream, std::forward(args)...); + + NVTE_CHECK_CUDA(cudaStreamEndCapture(stream, &graph)); + NVTE_CHECK_CUDA(cudaGraphDestroy(graph)); + + return error; +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index c35bc6668..523819392 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -180,6 +180,42 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("is_2x"), FFI_CudaGraph_Traits); +Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, + Buffer_Type gamma_buf, Buffer_Type beta_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type amax_buf, Result_Type mu_buf, + Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, + bool zero_centered_gamma, double epsilon, int64_t sm_margin, + JAXX_Scaling_Mode scaling_mode, bool is_2x) { + return wrapInStreamCapture( + std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, + colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, + wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // scale + .Arg() // gamma + .Arg() // beta + .Ret() // output + .Ret() // colwise_output + .Ret() // scale_inv + .Ret() // colwise_scale_inv + .Ret() // amax + .Ret() // mu + .Ret() // rsigma + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("epsilon") + .Attr("sm_margin") + .Attr("scaling_mode") + .Attr("is_2x")); + pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, NVTE_Norm_Type norm_type, bool zero_centered_gamma, int sm_margin) { @@ -305,5 +341,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI, .Attr("sm_margin"), FFI_CudaGraph_Traits); +Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type mu_buf, Buffer_Type rsigma_buf, + Buffer_Type gamma_buf, Result_Type xgrad_buf, + Result_Type wgrad_buf, Result_Type dbeta_buf, + Result_Type wkspace_buf, int64_t norm_type, + bool zero_centered_gamma, int64_t sm_margin) { + return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf, + rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf, + norm_type, zero_centered_gamma, sm_margin); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // mu + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // dbeta + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("sm_margin")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 06e2e2e00..36dd8205b 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -22,8 +22,12 @@ pybind11::dict Registrations() { pybind11::dict dict; // Activation - dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); - dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler); + dict["te_act_lu_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ActLuHandler)); + dict["te_dact_dbias_quantize_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler)); // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); @@ -44,9 +48,11 @@ pybind11::dict Registrations() { // Normalization dict["te_norm_forward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler)); dict["te_norm_backward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler)); // Attention From 7022d50fe1e95eafa7771b74029ea0e3bac9b6d7 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 1 Oct 2025 10:10:48 +0200 Subject: [PATCH 018/141] [PyTorch] Quantizer as API (#2039) * Introduce QuantizerBase Signed-off-by: Evgeny * Expose as a first-class API Signed-off-by: Evgeny * Undo QuantizerBase Signed-off-by: Evgeny * Make Quantizer a base class without implementations Signed-off-by: Evgeny * Support CustomRecipe and CustomRecipeState Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Resolving comments: quantize impl, num_quantizers, defaults Signed-off-by: Evgeny * Quantizer factories Signed-off-by: Evgeny * Add tests Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * QuantizedTensorBase _get_quantizer() + quantize_() Signed-off-by: Evgeny * Experimental note + LayerNormMLP fix Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tensor._internal -> tensor.base Signed-off-by: Evgeny * Expose Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor import fix Signed-off-by: Evgeny * Single quantizer factory with roles Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More context for qfactory, fwd/bwd_roles Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor Signed-off-by: Evgeny * Rename *Base -> *Storage quantized tensors Signed-off-by: Evgeny * make_quantizers() will take roles from the operation Signed-off-by: Evgeny * Improve tests and fix missing imports Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Merge main followup Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Evgeny Signed-off-by: Evgeny Tsykunov Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_custom_recipe.py | 290 ++++++++++++++++++ transformer_engine/common/recipe/__init__.py | 42 ++- .../debug/features/log_tensor_stats.py | 6 +- .../debug/pytorch/debug_quantization.py | 4 +- transformer_engine/pytorch/__init__.py | 15 + .../pytorch/cpp_extensions/gemm.py | 6 +- transformer_engine/pytorch/cpu_offload.py | 8 +- .../pytorch/csrc/extensions/cast.cpp | 4 +- .../pytorch/csrc/extensions/pybind.cpp | 36 +-- transformer_engine/pytorch/csrc/pybind.h | 16 +- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +- transformer_engine/pytorch/distributed.py | 80 ++--- .../pytorch/experimental/quantization.py | 6 +- transformer_engine/pytorch/fp8.py | 56 ++++ transformer_engine/pytorch/module/base.py | 60 ++-- .../pytorch/module/grouped_linear.py | 17 +- .../pytorch/module/layernorm_linear.py | 28 +- .../pytorch/module/layernorm_mlp.py | 56 ++-- transformer_engine/pytorch/module/linear.py | 30 +- transformer_engine/pytorch/ops/_common.py | 8 +- .../pytorch/ops/basic/basic_linear.py | 4 +- .../pytorch/ops/basic/dropout.py | 4 +- .../ops/fused/userbuffers_forward_linear.py | 4 +- transformer_engine/pytorch/tensor/__init__.py | 50 ++- .../pytorch/tensor/_internal/__init__.py | 4 - .../pytorch/tensor/float8_blockwise_tensor.py | 28 +- .../pytorch/tensor/float8_tensor.py | 38 ++- .../pytorch/tensor/mxfp8_tensor.py | 35 +-- .../pytorch/tensor/nvfp4_tensor.py | 10 +- .../pytorch/tensor/quantized_tensor.py | 90 ++++-- .../pytorch/tensor/storage/__init__.py | 9 + .../float8_blockwise_tensor_storage.py} | 10 +- .../float8_tensor_storage.py} | 14 +- .../mxfp8_tensor_storage.py} | 14 +- .../nvfp4_tensor_storage.py} | 12 +- transformer_engine/pytorch/tensor/utils.py | 8 +- transformer_engine/pytorch/utils.py | 8 +- 37 files changed, 808 insertions(+), 312 deletions(-) create mode 100644 tests/pytorch/test_custom_recipe.py delete mode 100644 transformer_engine/pytorch/tensor/_internal/__init__.py create mode 100644 transformer_engine/pytorch/tensor/storage/__init__.py rename transformer_engine/pytorch/tensor/{_internal/float8_blockwise_tensor_base.py => storage/float8_blockwise_tensor_storage.py} (98%) rename transformer_engine/pytorch/tensor/{_internal/float8_tensor_base.py => storage/float8_tensor_storage.py} (96%) rename transformer_engine/pytorch/tensor/{_internal/mxfp8_tensor_base.py => storage/mxfp8_tensor_storage.py} (97%) rename transformer_engine/pytorch/tensor/{_internal/nvfp4_tensor_base.py => storage/nvfp4_tensor_storage.py} (98%) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py new file mode 100644 index 000000000..cb840f197 --- /dev/null +++ b/tests/pytorch/test_custom_recipe.py @@ -0,0 +1,290 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch + +import transformer_engine as te +import transformer_engine_torch as tex +from transformer_engine.common import recipe +from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast +from transformer_engine.pytorch import Linear +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear +from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.module.grouped_linear import GroupedLinear + + +@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"]) +def test_custom_recipe_sanity(module_type): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + # Simple linear layer with dims divisible by 16 + in_features = 64 + out_features = 64 + batch = 32 + + if module_type == "Linear": + model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + elif module_type == "LayerNormLinear": + model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + elif module_type == "LayerNormMLP": + # hidden_size == in_features == out_features for simplicity + model = LayerNormMLP( + hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16 + ).cuda() + else: + # OpsLinear path + model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Single factory: map roles to quantizers + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + # Execute with custom recipe + with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Basic sanity: gradients exist + assert inp.grad is not None + + +def test_custom_recipe_grouped_linear_sanity(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + num_gemms = 3 + in_features = 64 + out_features = 64 + batch = 32 + base = batch // num_gemms + rem = batch % num_gemms + m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)] + + model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda() + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + out = model(inp, m_splits) + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None + + +def test_custom_recipe_matches_current_scaling(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(123) + + in_features = 64 + out_features = 64 + batch = 32 + + # Create two identical models + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + model_custom.load_state_dict(model_ref.state_dict()) + + # Identical inputs for both paths + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_custom = base_inp.clone().detach().requires_grad_(True) + + # Reference: use Float8CurrentScaling recipe + ref_recipe = recipe.Float8CurrentScaling() + with fp8_autocast(enabled=True, fp8_recipe=ref_recipe): + out_ref = model_ref(inp_ref) + # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd) + ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3 + assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3 + assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3 + assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2 + assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2 + + # Stress dynamic range in grad_output + scale = torch.ones(out_features, device="cuda", dtype=torch.float32) + scale[0] = 1e8 + scale[1] = 1e-8 + loss_ref = (out_ref.float() * scale.view(1, -1)).sum() + loss_ref.backward() + + # Custom: single factory returning quantizers per role to match Float8CurrentScaling + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + out_custom = model_custom(inp_custom) + # Assert dtypes for custom quantizers match reference mapping + cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3 + assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 + assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 + assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + + loss_custom = (out_custom.float() * scale.view(1, -1)).sum() + loss_custom.backward() + + # Compare forward outputs (exact match expected) + assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0) + + # Compare input gradients + assert inp_ref.grad is not None and inp_custom.grad is not None + assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0) + + # Compare parameter gradients (weights and bias if present) + ref_params = dict(model_ref.named_parameters()) + custom_params = dict(model_custom.named_parameters()) + for name, p_ref in ref_params.items(): + p_cus = custom_params[name] + assert p_ref.grad is not None and p_cus.grad is not None + assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0) + + +def test_custom_recipe_ops_linear_2_1_layout(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(7) + + in_features = 64 + out_features = 64 + batch = 16 + + # Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer + op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom = recipe.CustomRecipe(qfactory=quantizer_factory) + + with fp8_autocast(enabled=True, fp8_recipe=custom): + out = op(inp) + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None + + +def test_custom_recipe_factory_invocation_counts_and_cycling(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(13) + + in_features = 64 + out_features = 64 + batch = 8 + + op = Linear(in_features, out_features, params_dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Counters per role + counts = { + "linear_input": 0, + "linear_weight": 0, + "linear_output": 0, + "linear_grad_output": 0, + "linear_grad_input": 0, + } + + def quantizer_factory(role): + if role in counts: + counts[role] += 1 + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + + custom = recipe.CustomRecipe(qfactory=quantizer_factory) + + # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), + # and backward to build 2 quantizers (cycled from 1 factory). + with fp8_autocast(enabled=True, fp8_recipe=custom): + out = op(inp) + loss = out.float().sum() + loss.backward() + + # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input + assert counts["linear_input"] == 1 + assert counts["linear_weight"] == 1 + assert counts["linear_output"] == 1 + assert counts["linear_grad_output"] == 1 + assert counts["linear_grad_input"] == 1 + + +def test_factories_return_distinct_instances_and_buffers(): + available, reason = check_fp8_support() + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + # Two calls should produce distinct quantizer objects and distinct tensor buffers + def factory(): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + + q1 = factory() + q2 = factory() + + assert q1 is not q2 + assert q1.scale.data_ptr() != q2.scale.data_ptr() + assert q1.amax.data_ptr() != q2.amax.data_ptr() + + # Mutating one should not affect the other + q1.scale.fill_(123.0) + assert not torch.equal(q1.scale, q2.scale) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 179d618b3..324b5d50c 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -6,7 +6,8 @@ from __future__ import annotations import os from enum import Enum -from typing import Literal, Optional, Union, Callable, NamedTuple +from typing import Any, Literal, Optional, Union, Callable, NamedTuple +from dataclasses import field from pydantic.dataclasses import dataclass @@ -111,6 +112,10 @@ def float8_block_scaling(self): """Whether the given recipe is float8 blockwise scaling.""" return isinstance(self, Float8BlockScaling) + def custom(self): + """Whether the given recipe is custom.""" + return isinstance(self, CustomRecipe) + @dataclass() class DelayedScaling(Recipe): @@ -377,7 +382,6 @@ def __repr__(self) -> str: ) -@dataclass() class NVFP4BlockScaling(Recipe): """ Use the NVFP4 scaling strategy. @@ -456,3 +460,37 @@ def __repr__(self) -> str: f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " ) + + +@dataclass() +class CustomRecipe(Recipe): + """ + Custom recipe that allows users to provide quantizer factories. + + .. warning:: + **EXPERIMENTAL**: Custom recipe is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. + + Parameters + ---------- + qfactory : Callable + Factory callable that returns a quantizer instance for a + given semantic tensor role. + The callable is typically invoked as: + qfactory( + role: str, + ) + + Where `role` is one of the following strings for e.g. te.Linear + (stable public contract): + - forward: "linear_input", "linear_weight", "linear_output" + - backward: "linear_grad_output", "linear_grad_input" + """ + + qfactory: Callable[..., Any] + + fp8_dpa: bool = False + fp8_mha: bool = False + + def __repr__(self) -> str: + return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 7ba2f9f77..5d721d996 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -15,8 +15,8 @@ from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase -from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params @@ -123,7 +123,7 @@ def inspect_tensor( """API call used to collect the data about the tensor before process_tensor()/quantization.""" assert ( - type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase] + type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage] and tensor.dtype != torch.uint8 ), ( f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using" diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index d564ca8e9..185bf15d0 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -18,7 +18,7 @@ from transformer_engine.pytorch.tensor.quantized_tensor import ( QuantizedTensor, Quantizer, - QuantizedTensorBase, + QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) @@ -557,7 +557,7 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): self._update_parent_quantizer_usage() -class DebugQuantizedTensor(QuantizedTensorBase): +class DebugQuantizedTensor(QuantizedTensorStorage): """ Class containing quantized tensors after debug. Depending on configuration it can contain one or two different objects. These objects can be accessed by the method diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 3bdbe4089..3256512b5 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -56,6 +56,21 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch.tensor import Quantizer +from transformer_engine.pytorch.tensor import Float8Quantizer +from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor import QuantizedTensorStorage +from transformer_engine.pytorch.tensor import Float8TensorStorage +from transformer_engine.pytorch.tensor import MXFP8TensorStorage +from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage +from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor import Float8Tensor +from transformer_engine.pytorch.tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor import prepare_for_saving +from transformer_engine.pytorch.tensor import restore_from_saved try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index d330e023e..a45fafb68 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -12,7 +12,7 @@ from ..utils import get_sm_count, _empty_tensor from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.utils import is_experimental from ..experimental.gemm import experimental_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -107,9 +107,9 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] - if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, Float8BlockwiseQTensorStorage): # There is not use_split_accumulator == False - # implementation for Float8BlockwiseQTensorBase GEMM + # implementation for Float8BlockwiseQTensorStorage GEMM use_split_accumulator = True # Check that data format is supported diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 179c80a65..9378774ea 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -10,7 +10,7 @@ import torch from transformer_engine.debug.pytorch.debug_state import TEDebugState -from .tensor.quantized_tensor import QuantizedTensorBase +from .tensor.quantized_tensor import QuantizedTensorStorage from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] @@ -34,7 +34,7 @@ def mark_activation_offload(*tensors): if tensor is not None: tensor.activation_offloading = True # This is a hack to force clear the tensor after it is offloaded. - # It is needed, because .*TensorBase classes are saved in the ctx, + # It is needed, because .*TensorStorage classes are saved in the ctx, # and they contain the reference to their data tensors. tensor.needs_force_clear = True @@ -362,7 +362,7 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: ), ) - is_quantized_tensor = isinstance(tensor, QuantizedTensorBase) + is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage) if not torch_stray_tensor: @@ -514,7 +514,7 @@ def synchronize_on_group_commit_forward(self, current_group): if tensor_tag[0] == self.offloaded_group_count: if hasattr(tensor_buf, "needs_force_clear"): # Need to clear activation tensor - sometimes references persist in the code. - # This is the case for example with the Float8TensorBase class, + # This is the case for example with the Float8TensorStorage class, # which is saved directly inside the ctx while its internal tensors are # saved inside save_for_backward. tensor_buf.data = torch.Tensor() diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2c1edae4c..b6e9ef828 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -314,7 +314,7 @@ std::tuple, std::vector> bulk_allocate_fp // Construct FP8 block-wise tensors py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); for (size_t i = 0; i < num_tensors; ++i) { // Create tensor objects with proper reference counting py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); @@ -461,7 +461,7 @@ std::tuple, std::vector> bulk_allocate_mx } // Construct mxfp8 tensors - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); for (size_t i = 0; i < num_tensors; ++i) { // Create tensor objects with proper reference counting py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 382adbfb0..3b81393db 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -23,17 +23,17 @@ namespace transformer_engine::pytorch { PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8TensorStoragePythonClass = nullptr; PyTypeObject *Float8QuantizerClass = nullptr; PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8TensorStoragePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; -PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; -PyTypeObject *NVFP4TensorBasePythonClass = nullptr; +PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; void init_float8_extension() { @@ -46,9 +46,9 @@ void init_float8_extension() { Float8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); - Float8TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage"); + Float8TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage")); NVTE_CHECK(Float8TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch Float8 extension."); } @@ -61,29 +61,29 @@ void init_mxfp8_extension() { MXFP8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); - MXFP8TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage"); + MXFP8TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage")); NVTE_CHECK(MXFP8TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch MXFP8 extension."); } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorBasePythonClass) return; + if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( - "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + "transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage"); Float8BlockwiseQuantizerClass = reinterpret_cast( PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); - Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage")); Float8BlockwiseQTensorPythonClass = reinterpret_cast( PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); - NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); @@ -97,9 +97,9 @@ void init_nvfp4_extensions() { NVFP4TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); auto nvfp4_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.nvfp4_tensor_base"); - NVFP4TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage"); + NVFP4TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage")); NVTE_CHECK(NVFP4TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch NVFP4 extension."); } diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index f46edaa70..65665d01b 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -31,17 +31,17 @@ namespace transformer_engine::pytorch { } while (false); extern PyTypeObject *Float8TensorPythonClass; -extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8TensorStoragePythonClass; extern PyTypeObject *Float8QuantizerClass; extern PyTypeObject *Float8CurrentScalingQuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; -extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8TensorStoragePythonClass; extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; -extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; -extern PyTypeObject *NVFP4TensorBasePythonClass; +extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; void init_extension(); @@ -55,13 +55,13 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { } inline bool IsFloat8Tensor(PyObject *obj) { - return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass; } inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } inline bool IsMXFP8Tensor(PyObject *obj) { - return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass; } inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { @@ -72,11 +72,11 @@ inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4Quant inline bool IsFloat8BlockwiseQTensor(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || - Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; + Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass; } inline bool IsNVFP4Tensor(PyObject *obj) { - return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorBasePythonClass; + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass; } TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 8470466ae..42ae658f2 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -152,7 +152,7 @@ std::pair Float8Quantizer::create_tensor( // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); @@ -357,7 +357,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); @@ -630,7 +630,7 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); ret = Float8BlockwiseQTensorClass( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, @@ -950,7 +950,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, @@ -1230,7 +1230,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorBasePythonClass)); + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); out_py = NVFP4TensorClass( "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 3ab0717d0..c001e8e79 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -41,11 +41,11 @@ from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensorBase, QuantizedTensor, Quantizer -from .tensor._internal.float8_tensor_base import Float8TensorBase -from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from .tensor._internal.nvfp4_tensor_base import NVFP4TensorBase -from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer +from .tensor.storage.float8_tensor_storage import Float8TensorStorage +from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from .triton.pad import pad_columnwise_scale_inv from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer @@ -907,7 +907,7 @@ def _all_gather_fp8( async_op: bool = False, quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, -) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]: """All-gather FP8 tensor along first dimension.""" world_size = get_distributed_world_size(process_group) @@ -925,7 +925,7 @@ def _all_gather_fp8( # Cast input tensor to FP8 if needed # Note: We cannot directly all-gather the transposed FP8 tensor, # so temporarily modify quantizer to avoid creating FP8 transpose. - if not isinstance(inp, Float8TensorBase): + if not isinstance(inp, Float8TensorStorage): assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer @@ -940,7 +940,7 @@ def _all_gather_fp8( ) # Construct output tensor - out: Float8TensorBase + out: Float8TensorStorage if quantizer is not None: dtype = torch.float32 device = "cuda" @@ -958,7 +958,7 @@ def _all_gather_fp8( out._transpose = None out._transpose_invalid = True else: - raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer") # Assume scaling factors are identical across ranks out._scale_inv = inp._scale_inv @@ -1003,10 +1003,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: def _post_process_fp8_blockwise_gather( - out: Float8BlockwiseQTensorBase, + out: Float8BlockwiseQTensorStorage, quantizer: Float8BlockQuantizer, handle: Optional[torch.distributed.Work] = None, -) -> Float8BlockwiseQTensorBase: +) -> Float8BlockwiseQTensorStorage: """Post-process FP8 blockwise gather.""" if handle is not None: handle.wait() @@ -1040,7 +1040,7 @@ def _post_process_fp8_blockwise_gather( class _FP8BlockwiseAllGatherAsyncHandle: """Handle for asynchronous FP8 blockwise all-gather.""" - tensor: Float8BlockwiseQTensorBase + tensor: Float8BlockwiseQTensorStorage quantizer: Float8BlockQuantizer async_handle: torch.distributed.Work _synchronized: bool = False @@ -1078,18 +1078,18 @@ def _all_gather_fp8_blockwise( if isinstance(inp, torch.Tensor): device = inp.device dtype = inp.dtype - elif isinstance(inp, Float8BlockwiseQTensorBase): + elif isinstance(inp, Float8BlockwiseQTensorStorage): if inp._rowwise_data is not None: device = inp._rowwise_data.device elif inp._columnwise_data is not None: device = inp._columnwise_data.device else: - raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data") dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " - f"found {inp.__class__.__name__})" + "Invalid type for input tensor (expected torch.Tensor or" + f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})" ) world_size = get_distributed_world_size(process_group) @@ -1106,7 +1106,7 @@ def _all_gather_fp8_blockwise( # Doing BF16 gather for now as baseline because it's simpler if ( - not isinstance(inp, Float8BlockwiseQTensorBase) + not isinstance(inp, Float8BlockwiseQTensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1131,7 +1131,7 @@ def _all_gather_fp8_blockwise( # Set to compact usage in case the quantizer is not correctly configured orig_all_gather_usage = quantizer.all_gather_usage quantizer.all_gather_usage = True - if not isinstance(inp, Float8BlockwiseQTensorBase): + if not isinstance(inp, Float8BlockwiseQTensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1228,12 +1228,12 @@ def _swap_first_dims(tensor: torch.Tensor, world_size: int): def _post_process_nvfp4_gather( - out: NVFP4TensorBase, + out: NVFP4TensorStorage, columnwise_data_interleaved: torch.Tensor, columnwise_scale_inv_interleaved: torch.Tensor, world_size: int, handle: Optional[torch.distributed.Work] = None, -) -> NVFP4TensorBase: +) -> NVFP4TensorStorage: """Post-process FP8 blockwise gather.""" if handle is not None: handle.wait() @@ -1251,7 +1251,7 @@ def _post_process_nvfp4_gather( class _NVFP4AllGatherAsyncHandle: """Handle for asynchronous NVFP4 all-gather.""" - output: NVFP4TensorBase + output: NVFP4TensorStorage columnwise_data_interleaved: torch.Tensor columnwise_scale_inv_interleaved: torch.Tensor world_size: int @@ -1279,7 +1279,7 @@ def _all_gather_nvfp4( async_op: bool = False, quantizer: NVFP4Quantizer, out_shape: Optional[list[int]] = None, -) -> tuple[NVFP4TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]: """All-gather NVFP4 tensor along first dimension.""" # Input tensor attributes @@ -1289,7 +1289,7 @@ def _all_gather_nvfp4( dtype: torch.dtype # Construct packed shapes for input and input_t. - if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorBase): + if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage): # High-precision tensor. in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size()) in_shape_t = NVFP4Quantizer.convert_shape_for_fp4( @@ -1297,7 +1297,7 @@ def _all_gather_nvfp4( ) device = inp.device dtype = inp.dtype - elif isinstance(inp, NVFP4TensorBase): + elif isinstance(inp, NVFP4TensorStorage): if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device @@ -1307,7 +1307,7 @@ def _all_gather_nvfp4( dtype = torch.bfloat16 else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorBase, " + "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, " f"found {inp.__class__.__name__})" ) @@ -1321,7 +1321,7 @@ def _all_gather_nvfp4( # For cases where inp has dimensions that cannot be quantized, # we gather in high precision followed by a cast to NVFP4. if ( - not isinstance(inp, NVFP4TensorBase) + not isinstance(inp, NVFP4TensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1336,7 +1336,7 @@ def _all_gather_nvfp4( return out, None # Cast input tensor to NVFP4 with required data - if not isinstance(inp, NVFP4TensorBase): + if not isinstance(inp, NVFP4TensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1453,7 +1453,7 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, -) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" # Input tensor attributes @@ -1464,7 +1464,7 @@ def _all_gather_mxfp8( in_shape = inp.size() device = inp.device dtype = inp.dtype - elif isinstance(inp, MXFP8TensorBase): + elif isinstance(inp, MXFP8TensorStorage): if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device @@ -1476,7 +1476,7 @@ def _all_gather_mxfp8( dtype = torch.bfloat16 # Guess high-precision dtype. else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " + "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, " f"found {inp.__class__.__name__})" ) @@ -1488,7 +1488,7 @@ def _all_gather_mxfp8( # For cases where inp has dimensions that cannot be quantized, # we gather in high precision followed by a cast to FP8. if ( - not isinstance(inp, MXFP8TensorBase) + not isinstance(inp, MXFP8TensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1503,7 +1503,7 @@ def _all_gather_mxfp8( return out, None # Cast input tensor to MXFP8 with required data - if not isinstance(inp, MXFP8TensorBase): + if not isinstance(inp, MXFP8TensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1587,7 +1587,7 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(inp, QuantizedTensorBase): + if quantizer is not None and not isinstance(inp, QuantizedTensorStorage): inp = quantizer(inp) return inp, None @@ -1634,7 +1634,7 @@ def gather_along_first_dim( out_shape[0] *= world_size # FP8 case: delayed scaling or current scaling - if isinstance(inp, Float8TensorBase) or isinstance( + if isinstance(inp, Float8TensorStorage) or isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): return _all_gather_fp8( @@ -1646,7 +1646,9 @@ def gather_along_first_dim( ) # FP8 block scaling case, block length = 128 - if isinstance(inp, Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance( + quantizer, Float8BlockQuantizer + ): return _all_gather_fp8_blockwise( inp, process_group, @@ -1656,7 +1658,7 @@ def gather_along_first_dim( ) # MXFP8 case - if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) return _all_gather_mxfp8( inp, @@ -1667,7 +1669,7 @@ def gather_along_first_dim( ) # NVFP4 case - if isinstance(inp, NVFP4TensorBase) or isinstance(quantizer, NVFP4Quantizer): + if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): assert isinstance(quantizer, NVFP4Quantizer) return _all_gather_nvfp4( inp, @@ -1683,7 +1685,7 @@ def gather_along_first_dim( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(inp, QuantizedTensorBase): + if isinstance(inp, QuantizedTensorStorage): inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format @@ -1701,7 +1703,7 @@ def gather_along_first_dim( return out, None # Dequantize quantized tensor if not supported - if isinstance(inp, QuantizedTensorBase): + if isinstance(inp, QuantizedTensorStorage): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." diff --git a/transformer_engine/pytorch/experimental/quantization.py b/transformer_engine/pytorch/experimental/quantization.py index 9adf4dabf..7d573abac 100644 --- a/transformer_engine/pytorch/experimental/quantization.py +++ b/transformer_engine/pytorch/experimental/quantization.py @@ -13,7 +13,7 @@ import torch from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase, Quantizer +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer from transformer_engine.pytorch.experimental import utils @@ -36,7 +36,7 @@ class MMParams: @dataclasses.dataclass -class ExperimentalQuantizedTensor(QuantizedTensorBase): +class ExperimentalQuantizedTensor(QuantizedTensorStorage): """Base class for experimental quantized tensor containers. An experimental container to hold quantization result, including quantized tensor, optional @@ -187,7 +187,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, - ) -> QuantizedTensorBase: + ) -> QuantizedTensorStorage: raise NotImplementedError( f"{self.__class__.__name__} class does not implement make_empty function" ) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15017913f..a62e10bc5 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -22,6 +22,7 @@ Float8CurrentScaling, Float8BlockScaling, NVFP4BlockScaling, + CustomRecipe, ) from .constants import dist_group_type @@ -866,6 +867,8 @@ def create( cls = Float8BlockScalingRecipeState elif recipe.nvfp4(): cls = NVFP4BlockScalingRecipeState + elif recipe.custom(): + cls = CustomRecipeState else: raise ValueError(f"{recipe.__class__.__name__} is not supported") return cls( @@ -1191,3 +1194,56 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: ] raise RuntimeError(f"Unexpected recipe mode ({self.mode})") + + +class CustomRecipeState(RecipeState): + """State for CustomRecipe: produce quantizers per tensor.""" + + recipe: CustomRecipe + mode: str + num_quantizers: int + device: Optional[torch.device] + + def __init__( + self, + recipe: CustomRecipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + if device is None: + device = torch.device("cuda") + self.device = device + + if getattr(recipe, "qfactory", None) is None: + raise ValueError("CustomRecipe requires `qfactory`.") + + def make_quantizers(self) -> list: + qfactory = self.recipe.qfactory + out = [] + + # TODO(negvet): make_quantizers() should take roles from the operation + # Hardcode linear-specific roles for now + roles: List[str] + if self.mode == "forward": + roles = [ + ("linear_input", "linear_weight", "linear_output")[i % 3] + for i in range(self.num_quantizers) + ] + elif self.mode == "backward": + roles = [ + ("linear_grad_output", "linear_grad_input")[i % 2] + for i in range(self.num_quantizers) + ] + else: + roles = ["unknown"] * self.num_quantizers + + for i in range(self.num_quantizers): + # Get quantizer from the user defined factory + quantizer = qfactory(roles[i]) + out.append(quantizer) + return out diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index bf4fb97d2..d60ff8059 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -38,15 +38,15 @@ _fsdp_gather_tensors, ) from ..constants import dist_group_type -from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer +from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..tensor.storage.float8_tensor_storage import Float8TensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -505,7 +505,7 @@ def fill_userbuffers_buffer_for_all_gather( local_tensor: torch.Tensor, quantizer: Optional[Quantizer], process_group, -) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]: +) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]: """Fill local shard of Userbuffers buffer with data for all-gather Returns the full tensor and the local shard, both using the @@ -529,7 +529,7 @@ def fill_userbuffers_buffer_for_all_gather( # Unquantized data if quantizer is None: - if isinstance(local_tensor, QuantizedTensorBase): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor = local_tensor.dequantize() if comm.is_fp8_ubuf(): raise RuntimeError( @@ -542,8 +542,8 @@ def fill_userbuffers_buffer_for_all_gather( # FP8 data if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): - if not isinstance(local_tensor, Float8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): + if not isinstance(local_tensor, Float8TensorStorage): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor.dequantize() quantizer.set_usage(rowwise=True, columnwise=False) local_tensor = quantizer(local_tensor) @@ -554,7 +554,7 @@ def fill_userbuffers_buffer_for_all_gather( ) comm.copy_into_buffer(local_tensor._data, local_chunk=True) global_tensor_data = comm.get_buffer(shape=global_shape) - global_tensor = Float8TensorBase( + global_tensor = Float8TensorStorage( data=global_tensor_data, fp8_scale_inv=local_tensor._scale_inv, fp8_dtype=local_tensor._fp8_dtype, @@ -566,8 +566,8 @@ def fill_userbuffers_buffer_for_all_gather( if isinstance(quantizer, MXFP8Quantizer): # Cast to MXFP8 if needed - if not isinstance(local_tensor, MXFP8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): + if not isinstance(local_tensor, MXFP8TensorStorage): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor.dequantize() local_tensor = quantizer(local_tensor) if not comm.is_fp8_ubuf(): @@ -622,7 +622,7 @@ def fill_userbuffers_buffer_for_all_gather( rowwise_data, rowwise_scale_inv = global_data, global_scale_inv else: columnwise_data, columnwise_scale_inv = global_data, global_scale_inv - global_tensor = MXFP8TensorBase( + global_tensor = MXFP8TensorStorage( rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, @@ -786,10 +786,10 @@ def _update_weight_quantizers(self) -> None: f"({len(weight_quantizers)}) must match" ) for weight, quantizer in zip(weight_tensors, weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): + if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_quantizer(quantizer) - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement _get_weight_tensors function" @@ -1038,8 +1038,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(self.fp8_meta["recipe"], "fp8_format"): + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) @@ -1170,9 +1171,9 @@ def grad_output_preprocess( grad_output, ( QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, ), ): grad_output = quantizer(grad_output) @@ -1201,9 +1202,9 @@ def grad_output_preprocess( grad_output_.get_tensor(True), ( QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, ), ) and ctx.use_bias @@ -1219,7 +1220,12 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance( grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ( + QuantizedTensor, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, + ), ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: @@ -1229,7 +1235,7 @@ def grad_output_preprocess( grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance(grad_output, QuantizedTensorBase): + if not isinstance(grad_output, QuantizedTensorStorage): grad_output = quantizer(grad_output) return grad_output, grad_bias @@ -1383,14 +1389,14 @@ def get_weight_workspace( # Reset cache if workspace is invalid if out is not None and quantizer is not None: reset_cache = False - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): if ( not is_non_tn_fp8_gemm_supported() and quantizer.columnwise_usage and out._transpose is None ): reset_cache = True - elif isinstance(out, MXFP8TensorBase): + elif isinstance(out, MXFP8TensorStorage): if quantizer.rowwise_usage and out._rowwise_data is None: reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: @@ -1581,7 +1587,7 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, QuantizedTensorBase): + if isinstance(tensor, QuantizedTensorStorage): quantizer = tensor._get_quantizer() if quantizer is None: continue diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 5749d96c9..b3adfb7db 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -44,7 +44,7 @@ from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.quantized_tensor import ( - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -200,13 +200,13 @@ def forward( inputmats[0] = inp else: for inputmat in inputmats: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms if inp.requires_grad: for weight in weights_fp8: - if isinstance(weight, QuantizedTensorBase): + if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) tensors_to_save, tensor_objects = prepare_for_saving( @@ -338,7 +338,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) for weight, quantizer in zip(weights, ctx.weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): + if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_usage( rowwise_usage=quantizer.rowwise_usage, columnwise_usage=quantizer.columnwise_usage, @@ -734,7 +734,7 @@ def forward( produced) """ assert not isinstance( - inp, QuantizedTensorBase + inp, QuantizedTensorStorage ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." @@ -868,16 +868,17 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] - if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors): + if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " "Please make sure this is intentional." ) weight_tensors = [ - w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors + w.dequantize() if isinstance(w, QuantizedTensorStorage) else w + for w in weight_tensors ] return weight_tensors diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 6dbbd335e..e1c0eab2d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -58,7 +58,7 @@ from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers from ..tensor.quantized_tensor import ( QuantizedTensor, - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -66,8 +66,8 @@ from ...debug.pytorch.debug_state import TEDebugState from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload @@ -200,7 +200,7 @@ def forward( and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not experimental + and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) # Apply normalization @@ -278,7 +278,7 @@ def forward( weightmat = weight quantized_weight = False if fp8 or debug: - quantized_weight = not isinstance(weight, QuantizedTensorBase) + quantized_weight = not isinstance(weight, QuantizedTensorStorage) # Configure quantizer if weight_quantizer is not None: @@ -403,18 +403,18 @@ def forward( # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: - if isinstance(ln_out, QuantizedTensorBase): + if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data # can be allgathered. if ( - isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) + isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage)) or not ctx.ln_out_needs_gather ): ln_out.update_usage(rowwise_usage=False) # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading: @@ -685,9 +685,9 @@ def backward( # -------------------------------------------------- # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -806,14 +806,14 @@ def backward( ln_out_total_work.wait() ln_out_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): + if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -999,7 +999,7 @@ def wgrad_gemm( nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + # if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( @@ -1790,7 +1790,7 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a0e5f3aed..2097f01b1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -71,7 +71,7 @@ from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..tensor.quantized_tensor import ( - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -116,9 +116,14 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "swiglu": (tex.swiglu, tex.dswiglu, None), } # no activation fusion written yet - # Per-tensor current scaling or fp8 blockwise scaling: [] + # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: [] # TODO(ksivaman): Fuse nvfp4 act once kernel is available. - if recipe.float8_current_scaling() or recipe.float8_block_scaling() or recipe.nvfp4(): + if ( + recipe.float8_current_scaling() + or recipe.float8_block_scaling() + or recipe.nvfp4() + or recipe.custom() + ): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), @@ -448,10 +453,18 @@ def forward( act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): - # tex.quantize does not support GELU fusion for blockwise. - act_out = activation_func(fc1_out, None) - act_out = tex.quantize(act_out, fc2_input_quantizer) + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise + act_out = activation_func(fc1_out, None) + act_out = tex.quantize(act_out, fc2_input_quantizer) + elif recipe.custom(): + # tex.quantize does not support custom quantizers + act_out = activation_func(fc1_out, None) + act_out = fc2_input_quantizer(act_out) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) else: if fp8_calibration: act_out = activation_func(fc1_out, None) @@ -522,9 +535,9 @@ def forward( if is_grad_enabled: # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(fc1_weight_final, QuantizedTensorBase): + if isinstance(fc1_weight_final, QuantizedTensorStorage): fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorBase): + if isinstance(fc2_weight_final, QuantizedTensorStorage): fc2_weight_final.update_usage(columnwise_usage=True) if cpu_offloading: @@ -823,10 +836,10 @@ def backward( ) # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorBase + ctx.fc2_weight, QuantizedTensorStorage ): ctx.fc2_weight.update_usage(columnwise_usage=True) @@ -908,14 +921,14 @@ def backward( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.fp8 or ctx.debug: - if isinstance(act_out, QuantizedTensorBase): + if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1023,10 +1036,13 @@ def fc2_wgrad_gemm( ) # activation in high precision if ctx.fp8: - # TODO float8 blockwise current scaling has no bgrad fusion for now + # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now # TODO(ksivaman): Re-add fusion once kernel is available. - if isinstance( - ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + if ( + isinstance( + ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) + ) + or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) @@ -1072,7 +1088,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorBase + ctx.fc1_weight_quantizer, QuantizedTensorStorage ): ctx.fc1_weight.update_usage(columnwise_usage=True) @@ -1143,7 +1159,7 @@ def fc2_wgrad_gemm( ln_out_total_work.wait() ln_out_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): + if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1153,7 +1169,7 @@ def fc2_wgrad_gemm( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.fp8 or ctx.debug: - if isinstance(dact, QuantizedTensorBase): + if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -2153,7 +2169,7 @@ def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" return [self.fc1_weight, self.fc2_weight] diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index cf7f58947..02872439a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -59,7 +59,7 @@ from ..graph import is_graph_capturing from ..tensor.quantized_tensor import ( QuantizedTensor, - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -178,7 +178,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase) and not experimental: + if not isinstance(inputmat, QuantizedTensorStorage) and not experimental: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -216,7 +216,7 @@ def forward( else: # Do not all-gather input tensor if fp8 or debug: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat.update_usage(rowwise_usage=True) else: if input_quantizer is None: @@ -372,7 +372,7 @@ def forward( if ( backward_needs_input and own_quantized_input - and isinstance(inputmat, QuantizedTensorBase) + and isinstance(inputmat, QuantizedTensorStorage) ): if ( ctx.backward_input_needs_gather @@ -391,7 +391,7 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM. if inp.requires_grad: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading and saved_inputmat is not None: @@ -404,7 +404,7 @@ def forward( ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, - weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None, + weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -613,7 +613,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.requires_wgrad: if ctx.fp8 or ctx.debug: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass elif ctx.debug or ctx.experimental: @@ -632,7 +632,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) else: inputmat = cast_if_needed(inputmat, ctx.activation_dtype) @@ -677,9 +677,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase): + if ctx.weight_quantizer is not None and isinstance( + weight_fp8, QuantizedTensorStorage + ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -763,7 +765,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work.wait() inputmat_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(inputmat_total, QuantizedTensorBase): + if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) @@ -805,7 +807,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -958,7 +960,7 @@ def wgrad_gemm( nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( wgrad, @@ -1524,7 +1526,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad): for name, q in zip(names, original_quantizers) ) - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 99bbc34c4..13db35fc7 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -13,17 +13,17 @@ from .. import torch_version from ..fp8 import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor -from ..tensor.quantized_tensor import QuantizedTensorBase +from ..tensor.quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype -def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool: +def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" - return isinstance(tensor, QuantizedTensorBase) + return isinstance(tensor, QuantizedTensorStorage) def maybe_dequantize( - tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None + tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None ) -> torch.Tensor: """Dequantize tensor to given dtype or just convert if not a quantized tensor""" if is_quantized_tensor(tensor): diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index f8f95cf19..844e49ff0 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -29,7 +29,7 @@ ) from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -568,7 +568,7 @@ def _functional_forward( # Prepare input tensor for backward pass if weight_requires_grad: if with_quantized_compute and is_quantized_tensor(x_local): - if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): + if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 30ccf5ebc..38b2a59a7 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -11,7 +11,7 @@ import transformer_engine_torch as tex from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -56,7 +56,7 @@ def op_forward( out = input_ elif impl == "fused": x = input_ - if not isinstance(x, Float8TensorBase): + if not isinstance(x, Float8TensorStorage): x = maybe_dequantize(x, dtype=dtype) out, mask = tex.dropout_fwd(x, self.dropout_probability) elif impl == "unfused": diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index a604e57dc..cbbe529d6 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -23,7 +23,7 @@ ) from ...tensor.quantized_tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_dequantize, is_quantized_tensor from ..basic import BasicLinear, Bias, ReduceScatter from ..op import ( @@ -267,7 +267,7 @@ def _functional_forward( # Prepare input tensor for backward pass if weight_requires_grad: if with_quantized_compute and is_quantized_tensor(x_local): - if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): + if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 43846512d..7689e2019 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,12 +6,42 @@ import torch -from .quantized_tensor import QuantizedTensor, Quantizer +from .quantized_tensor import ( + QuantizedTensorStorage, + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from .storage.float8_tensor_storage import Float8TensorStorage +from .storage.mxfp8_tensor_storage import MXFP8TensorStorage +from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer +from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ - "QuantizedTensor", "Quantizer", + "Float8Quantizer", + "Float8CurrentScalingQuantizer", + "MXFP8Quantizer", + "Float8BlockQuantizer", + "NVFP4Quantizer", + "QuantizedTensorStorage", + "Float8TensorStorage", + "MXFP8TensorStorage", + "Float8BlockwiseQTensorStorage", + "NVFP4TensorStorage", + "QuantizedTensor", + "Float8Tensor", + "MXFP8Tensor", + "Float8BlockwiseQTensor", + "NVFP4Tensor", + "prepare_for_saving", + "restore_from_saved", ] @@ -48,24 +78,16 @@ def get_all_tensor_types(): """ Get all tensor-like types that can be used in TE. """ - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase - from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase - from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( - Float8BlockwiseQTensor, - Float8BlockwiseQTensorBase, - ) - from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Tensor, NVFP4TensorBase - all_tensor_types = [ torch.Tensor, torch.nn.Parameter, Float8Tensor, - Float8TensorBase, + Float8TensorStorage, MXFP8Tensor, - MXFP8TensorBase, + MXFP8TensorStorage, Float8BlockwiseQTensor, - Float8BlockwiseQTensorBase, + Float8BlockwiseQTensorStorage, NVFP4Tensor, - NVFP4TensorBase, + NVFP4TensorStorage, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_internal/__init__.py b/transformer_engine/pytorch/tensor/_internal/__init__.py deleted file mode 100644 index e13014bf7..000000000 --- a/transformer_engine/pytorch/tensor/_internal/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. -"""Internal data structures for quantized tensors.""" diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 0e41fc9c5..16631a3d0 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -13,8 +13,12 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat from transformer_engine.common.recipe import Float8BlockScaling, Recipe -from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .quantized_tensor import ( + QuantizedTensor, + Quantizer, + _IdentityFunc, +) from ..utils import devices_match, round_up_to_nearest_multiple aten = torch.ops.aten @@ -101,6 +105,10 @@ def update_quantized( dst._fp8_dtype = self.dtype return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: """Calculate the shape of the scaling tensor for blockwise quantization. @@ -270,7 +278,7 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling -class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): +class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. The tensor presents as having a standard, higher-precision dtype, @@ -295,7 +303,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): holds configuration about quantization and dequantization modes. """ - # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, @@ -334,15 +342,6 @@ def __repr__(self, *, tensor_contents=None): f" data_format={self._data_format}" ) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. - - """ - assert self._quantizer is not None - return self._quantizer - def quantize_( self, tensor: torch.Tensor, @@ -361,8 +360,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize()) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 18750d039..a4e68e53b 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -13,8 +13,12 @@ from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match -from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func +from .quantized_tensor import ( + QuantizedTensor, + Quantizer, + _IdentityFunc, +) from ..constants import dist_group_type aten = torch.ops.aten @@ -89,6 +93,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def make_empty( self, shape: Iterable[int], @@ -147,7 +155,7 @@ def create_tensor_from_data( torch.float8_e5m2fnuz, ] if internal: - return Float8TensorBase( + return Float8TensorStorage( data=data, fp8_scale_inv=1 / self.scale, fp8_dtype=self.dtype, @@ -271,6 +279,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def make_empty( self, shape: Iterable[int], @@ -333,7 +345,7 @@ def create_tensor_from_data( torch.float8_e5m2fnuz, ] if internal: - return Float8TensorBase( + return Float8TensorStorage( data=data, fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, @@ -388,7 +400,7 @@ def supports_only_rowwise_all_gather(self) -> bool: return True -class Float8Tensor(Float8TensorBase, QuantizedTensor): +class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -443,19 +455,6 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: return _FromFloat8Func.apply(self, dtype) return _FromFloat8Func.forward(None, self, dtype) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. - - """ - if self._quantizer is not None: - return self._quantizer - # Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling) - raise ValueError( - "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable" - ) - def quantize_( self, tensor: torch.Tensor, @@ -474,8 +473,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def detach(self) -> Float8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d7f5f8c7d..700de24c4 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -16,8 +16,12 @@ from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple -from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func +from .quantized_tensor import ( + QuantizedTensor, + Quantizer, + _IdentityFunc, +) aten = torch.ops.aten @@ -67,6 +71,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: @@ -161,14 +169,14 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor) return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) - def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor: + def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor: return torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling -class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): +class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -192,7 +200,7 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): """ - # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, @@ -236,17 +244,9 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: return _FromMXFP8Func.apply(self, dtype) return _FromMXFP8Func.forward(None, self, dtype) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. - - """ - if self._quantizer is not None: - return self._quantizer - return MXFP8Quantizer( - fp8_dtype=self._fp8_dtype, - ) + def _build_default_quantizer(self) -> Optional[Quantizer]: + """Build default quantizer for the tensor""" + return MXFP8Quantizer(fp8_dtype=self._fp8_dtype) def quantize_( self, @@ -266,8 +266,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize()) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def detach(self) -> MXFP8Tensor: # pylint: disable=missing-function-docstring diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index b12e89956..ca2154f55 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -21,7 +21,7 @@ round_up_to_nearest_multiple, ) -from ._internal.nvfp4_tensor_base import NVFP4TensorBase, _FromNVFP4Func +from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc aten = torch.ops.aten @@ -173,6 +173,10 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: @@ -332,7 +336,7 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return NVFP4BlockScaling -class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): +class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor): """Quantized tensor class with FP4 data The tensor presents as having a standard, higher-precision dtype, @@ -365,7 +369,7 @@ class NVFP4Tensor(NVFP4TensorBase, QuantizedTensor): Nominal tensor datatype, used in dequantize. """ - # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 7b88d2519..a524d5c8d 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -5,7 +5,7 @@ """Tensor with quantized data""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Any, Dict, Union +from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union import abc import copy import warnings @@ -13,12 +13,11 @@ import torch from torch.utils._pytree import tree_map -import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -class QuantizedTensorBase: - r"""Base class for all *TensorBase classes. +class QuantizedTensorStorage: + r"""Base class for all *TensorStorage classes. This class (and its subclasses) are optimization for when the full QuantizedTensor is not needed (when it is fully @@ -26,9 +25,9 @@ class QuantizedTensorBase: PyTorch's autograd). When creating a new tensor type X one should create both - XTensorBase class inheriting from QuantizedTensorBase and - XTensor inheriting from XTensorBase and QuantizedTensor. - XTensorBase should contain all data members needed to + XTensorStorage class inheriting from QuantizedTensorStorage and + XTensor inheriting from XTensorStorage and QuantizedTensor. + XTensorStorage should contain all data members needed to implement the functionality of the tensor, while XTensor should only implement the functionality needed to behave like regular torch.Tensor (liek __torch_dispatch__).""" @@ -59,7 +58,7 @@ def update_usage( f"{self.__class__.__name__} class does not implement update_usage function" ) - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement prepare_for_saving function" @@ -73,6 +72,30 @@ def restore_from_saved( f"{self.__class__.__name__} class does not implement restore_from_saved function" ) + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return self._build_default_quantizer() + + def _build_default_quantizer(self) -> Quantizer: + """Build default quantizer for the tensor""" + raise ValueError( + f"{self.__class__.__name__} has no quantizer " + "and no default quantizer is available defined in the subclass." + ) + + def quantize_( + self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None + ) -> QuantizedTensor: + """Quantize tensor in-place""" + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + def update_quantizer(self, quantizer: Quantizer): """Update quantizer for the tensor""" if self._quantizer is None: @@ -83,13 +106,13 @@ def update_quantizer(self, quantizer: Quantizer): def prepare_for_saving( - *tensors: Union[torch.Tensor, QuantizedTensorBase], + *tensors: Union[torch.Tensor, QuantizedTensorStorage], ) -> Tuple[ - list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]] + list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]] ]: """Prepare tensors for saving. Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save - the internal TensorBase types too.""" + the internal *TensorStorage types too.""" tensor_list, tensor_objects_list = [], [] for tensor in tensors: @@ -104,12 +127,12 @@ def prepare_for_saving( def restore_from_saved( - tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]], + tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], return_saved_tensors: bool = False, ) -> ( - list[Optional[torch.Tensor | QuantizedTensorBase]] - | tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]] + list[Optional[torch.Tensor | QuantizedTensorStorage]] + | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]] ): """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] @@ -178,7 +201,6 @@ def __repr__(self): ")" ) - @abc.abstractmethod def update_quantized( self, src: torch.Tensor, @@ -187,6 +209,9 @@ def update_quantized( noop_flag: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Quantize tensor in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_quantized" + ) def quantize( self, @@ -199,8 +224,14 @@ def quantize( if out is not None: return self.update_quantized(tensor, out) if (not self.internal) and torch.is_grad_enabled(): - return _QuantizeFunc.apply(tensor, self) - return _QuantizeFunc.forward(None, tensor, self) + return _QuantizeFunc.apply(tensor, self.quantize_impl) + return _QuantizeFunc.forward(None, tensor, self.quantize_impl) + + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement quantize_impl function" + ) def multi_quantize(self, list_of_tensors): """Quantize multiple tensors""" @@ -213,7 +244,6 @@ def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor""" return self.quantize(tensor) - @abc.abstractmethod def make_empty( self, shape: Iterable[int], @@ -222,8 +252,11 @@ def make_empty( device: Optional[torch.device] = None, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function, " + "required for construction of unintialized quantized tensor" + ) - @abc.abstractmethod def calibrate(self, tensor: torch.Tensor) -> None: """Calibrate quantizer state @@ -252,13 +285,21 @@ def copy(self) -> Quantizer: def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Symbolic function for ONNX export""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement onnx_quantize" + ) def onnx_dequantize(self, tensor) -> torch.Tensor: """Symbolic function for ONNX export""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement onnx_dequantize" + ) - @abc.abstractmethod def _get_compatible_recipe(self) -> Union[type[Recipe], None]: """Returns recipe class that is compatible with this quantizer""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_compatible_recipe" + ) def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" @@ -270,20 +311,21 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-a class _QuantizeFunc(torch.autograd.Function): - """Cast to FP8 from other dtype""" + """Quantize tensor""" @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused tensor: torch.Tensor, - quantizer: Quantizer, + quantize_impl: Callable, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring - return tex.quantize(tensor, quantizer) + return quantize_impl(tensor) @staticmethod def backward( - _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, ) -> Tuple[Optional[torch.Tensor], ...]: # pylint: disable=missing-function-docstring # Assume that we want gradients in full precision diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py new file mode 100644 index 000000000..9cb228f3a --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Storage for quantized tensors.""" + +from .float8_tensor_storage import Float8TensorStorage # noqa: F401 +from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 +from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py similarity index 98% rename from transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index da0220eb7..9040ea3a4 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -13,7 +13,7 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import Float8BlockScaleTensorFormat -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ...constants import TE_DType_To_Torch @@ -22,7 +22,7 @@ from ...utils import _empty_tensor -class Float8BlockwiseQTensorBase(QuantizedTensorBase): +class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8BlockwiseQTensor. Float8BlockwiseQTensor inherits from the PyTorch tensor class and this @@ -53,7 +53,7 @@ def __new__( *args, **kwargs, ): - if cls is Float8BlockwiseQTensorBase: + if cls is Float8BlockwiseQTensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -98,7 +98,7 @@ def _is_gemm_ready_format(self) -> bool: def prepare_for_saving( self, - ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]: """ Prepare the tensor base for saving for backward """ @@ -366,7 +366,7 @@ def __repr__(self): data = self.dequantize() descriptor = "columnwise" return ( - "Float8BlockwiseQTensorBase(" + "Float8BlockwiseQTensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"{descriptor}_scaled_data={data}" ) diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py similarity index 96% rename from transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 6d4822344..b9533edb6 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -12,7 +12,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ...constants import TE_DType as torch_to_transformer_engine_dtype @@ -27,7 +27,7 @@ class _FromFloat8Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: Float8TensorBase, + tensor: Float8TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -52,7 +52,7 @@ def backward( return grad, None -class Float8TensorBase(QuantizedTensorBase): +class Float8TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8Tensor. Float8Tensor inherits from the PyTorch tensor class and this mixin @@ -81,7 +81,7 @@ def __new__( quantizer: Optional[Quantizer] = None, **kwargs, ): - if cls is Float8TensorBase: + if cls is Float8TensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -116,7 +116,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [self._data, self._transpose, self._scale_inv] self._data = None @@ -163,7 +163,7 @@ def view(self, shape: torch.Size): if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]: out_transpose = None - return Float8TensorBase( + return Float8TensorStorage( data=out_data, fp8_scale_inv=self._scale_inv, fp8_dtype=self._fp8_dtype, @@ -173,7 +173,7 @@ def view(self, shape: torch.Size): def __repr__(self): return ( - "Float8TensorBase(" + "Float8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " f"data={self.dequantize()}" diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py similarity index 97% rename from transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 5a7dd6b44..c1f30146c 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -13,7 +13,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ...constants import TE_DType as torch_to_transformer_engine_dtype @@ -28,7 +28,7 @@ class _FromMXFP8Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: MXFP8TensorBase, + tensor: MXFP8TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -49,7 +49,7 @@ def backward( return grad, None -class MXFP8TensorBase(QuantizedTensorBase): +class MXFP8TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of MXFP8Tensor. MXFP8Tensor inherits from the PyTorch tensor class and this mixin @@ -77,7 +77,7 @@ def __new__( *args, **kwargs, ): - if cls is MXFP8TensorBase: + if cls is MXFP8TensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -112,7 +112,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [ self._rowwise_data, @@ -192,7 +192,7 @@ def view(self, shape: torch.Size): if cur_columnwise_data is not None: new_columnwise_data = cur_columnwise_data.view(*shape) - return MXFP8TensorBase( + return MXFP8TensorStorage( rowwise_data=new_rowwise_data, rowwise_scale_inv=self._rowwise_scale_inv, columnwise_data=new_columnwise_data, @@ -205,7 +205,7 @@ def __repr__(self): data_rowwise = self.dequantize() return ( - "MXFP8TensorBase(" + "MXFP8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"rowwise_scaled_data={data_rowwise}" f"rowwise_scale_inv={self._rowwise_scale_inv}, " diff --git a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py similarity index 98% rename from transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index df187d674..350103f7c 100644 --- a/transformer_engine/pytorch/tensor/_internal/nvfp4_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -16,7 +16,7 @@ # import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage # from ...constants import TE_DType as torch_to_transformer_engine_dtype from ..quantized_tensor import Quantizer @@ -39,7 +39,7 @@ class _FromNVFP4Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: NVFP4TensorBase, + tensor: NVFP4TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -89,7 +89,7 @@ def backward( return grad, None -class NVFP4TensorBase(QuantizedTensorBase): +class NVFP4TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of NVFP4Tensor. NVFP4Tensor inherits from the PyTorch tensor class and this mixin @@ -161,7 +161,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [ self._rowwise_data, @@ -267,7 +267,7 @@ def view(self, shape: torch.Size): new_columnwise_data = self._columnwise_data.view(byte_shape) # Construct tensor - return NVFP4TensorBase( + return NVFP4TensorStorage( rowwise_data=new_rowwise_data, rowwise_scale_inv=self._rowwise_scale_inv, columnwise_data=new_columnwise_data, @@ -282,7 +282,7 @@ def __repr__(self): data_rowwise = self.dequantize() return ( - "NVFP4TensorBase(" + "NVFP4TensorStorage(" f"rowwise_scaled_data={data_rowwise}," f"rowwise_scale_inv={self._rowwise_scale_inv}," f"amax_rowwise={self._amax_rowwise}," diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index a4bdf5e07..cc0249401 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorBase +from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer @@ -454,7 +454,7 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ) -def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) -> bool: +def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: """Check if an environment or object is using experimental Kitchen middleware. Returns False if x is a torch.Tensor. @@ -466,6 +466,6 @@ def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorBase]] = None) - # Detect if the object is experimental if isinstance(x, torch.Tensor): return False - if not isinstance(x, (Quantizer, QuantizedTensorBase)): - raise AssertionError("Object must be a Quantizer or QuantizedTensorBase instance") + if not isinstance(x, (Quantizer, QuantizedTensorStorage)): + raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance") return hasattr(x, "experimental") and x.experimental diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 1a0722f89..8ea362371 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -225,13 +225,15 @@ def forward( ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections from transformer_engine.pytorch.float8_tensor import Float8Tensor - from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase + from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( + Float8TensorStorage, + ) - if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( + if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance( mixed_x_layer, Float8Tensor ): return tuple( - Float8TensorBase( + Float8TensorStorage( fp8_scale_inv=mixed_x_layer._scale_inv, fp8_dtype=mixed_x_layer._fp8_dtype, data=x.squeeze(split_dim) if squeeze else x, From ac4e0fd63afb1998904695e1321c5631192c3a85 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 1 Oct 2025 10:02:26 -0400 Subject: [PATCH 019/141] [JAX] Rework amax reduction over TPSP (#2218) * rm using_global_amax_of_x Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/quantization.py | 44 +++++++++++++------ transformer_engine/jax/dense.py | 16 +++---- 2 files changed, 35 insertions(+), 25 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 021af4c9d..9f9e8fec0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -551,7 +551,10 @@ class AmaxCalculationPrimitive(BasePrimitive): name = "jax_local_amax" multiple_results = False - impl_static_args = (1,) # amax_scope + impl_static_args = ( + 1, + 2, + ) # amax_scope, batch_sequence_transpose inner_primitive = None outer_primitive = None @@ -560,11 +563,12 @@ def abstract( x_aval, *, amax_scope, + batch_sequence_transpose, ): """ amax calcuation abstract """ - del amax_scope + del amax_scope, batch_sequence_transpose dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -576,17 +580,19 @@ def abstract( def impl( x, amax_scope, + batch_sequence_transpose, ): """ amax calcuation implementation """ - del amax_scope + del amax_scope, batch_sequence_transpose amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) return amax @staticmethod def infer_sharding_from_operands( amax_scope, + batch_sequence_transpose, mesh, arg_infos, result_infos, @@ -594,7 +600,7 @@ def infer_sharding_from_operands( """ amax calcuation infer_sharding_from_operands """ - del (amax_scope, arg_infos, result_infos) # Unused. + del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused. amax_sharding = NamedSharding( mesh, PartitionSpec(None), @@ -605,6 +611,7 @@ def infer_sharding_from_operands( @staticmethod def partition( amax_scope, + batch_sequence_transpose, mesh, arg_infos, result_infos, @@ -613,25 +620,26 @@ def partition( amax calcuation partition """ del result_infos - + x_spec = get_padded_spec(arg_infos[0]) amax_sharding = NamedSharding( mesh, PartitionSpec(None), - desc="AmaxCalculationPrimitive.out_sharding", + desc="AmaxCalculation.amax_sharding", ) def sharded_impl(x): amax = AmaxCalculationPrimitive.impl( x, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) - if amax_scope is AmaxScope.TPSP: # Run AR across TP/SP - gmesh = global_mesh_resource() - amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tp_resource, mesh) + gmesh = global_mesh_resource() + sequence_dim = 0 if batch_sequence_transpose else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource: amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) - - if amax_scope is AmaxScope.FSDP: # Run AR across FSDP - gmesh = global_mesh_resource() + # Run AR across FSDP + if amax_scope is AmaxScope.FSDP: amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) return amax @@ -640,11 +648,11 @@ def sharded_impl(x): return mesh, sharded_impl, amax_sharding, arg_shardings @staticmethod - def shardy_sharding_rule(amax_scope, mesh, value_types, result_types): + def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types): """ amax calcuation shardy_sharding_rule """ - del amax_scope, mesh, result_types + del amax_scope, batch_sequence_transpose, mesh, result_types prefix = "AmaxCal" input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) output_spec = (f"{prefix}_amax",) @@ -701,6 +709,7 @@ def _quantize_dbias_impl( dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling + batch_sequence_transpose: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -745,6 +754,8 @@ def _quantize_dbias_impl( quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis, + amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias @@ -760,6 +771,7 @@ def _quantize_dbias_impl( amax = AmaxCalculationPrimitive.outer_primitive.bind( x.data, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: @@ -833,6 +845,7 @@ def quantize( quantizer: Quantizer, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, + batch_sequence_transpose: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -853,6 +866,7 @@ def quantize( quantizer=quantizer, flatten_axis=flatten_axis, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) return out @@ -863,6 +877,7 @@ def quantize_dbias( is_dbias: bool = True, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, + batch_sequence_transpose: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -889,6 +904,7 @@ def quantize_dbias( is_dbias=is_dbias, flatten_axis=flatten_axis, amax_scope=amax_scope, + batch_sequence_transpose=batch_sequence_transpose, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 23df1a0ce..3cdf6ba7a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -67,7 +67,6 @@ def dense( input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, output_axes: Tuple[str, ...] = None, - using_global_amax_of_x: bool = False, collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, ): @@ -86,7 +85,6 @@ def dense( input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix output_axes: Logical axes for sharding the output - using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types @@ -109,14 +107,13 @@ def dense( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) def _dense( x, kernel, @@ -126,7 +123,6 @@ def _dense( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, # need to be a diff_arg for DelayedScaling state management ): @@ -144,7 +140,6 @@ def _dense( input_axes: Logical axes for sharding the activation input output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix - using_global_amax_of_x: Indicate wether to use global amax for x. Only works when using current-scaling. Default is False. collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types @@ -160,7 +155,6 @@ def _dense( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, ) @@ -176,7 +170,6 @@ def _dense_fwd_rule( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, quantizer_set, ): @@ -203,7 +196,8 @@ def _dense_fwd_rule( x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, - amax_scope=AmaxScope.TPSP if using_global_amax_of_x else AmaxScope.LOCAL, + amax_scope=AmaxScope.TPSP, + batch_sequence_transpose=batch_sequence_transpose, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -250,7 +244,6 @@ def _dense_bwd_rule( input_axes, kernel_axes, output_axes, - using_global_amax_of_x, collective_op_set, ctx, grad, @@ -280,7 +273,8 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, - amax_scope=AmaxScope.LOCAL if using_global_amax_of_x else AmaxScope.TPSP, + amax_scope=AmaxScope.TPSP, + batch_sequence_transpose=batch_sequence_transpose, ) # GEMM NT From b0d562d8ac3f0ce36131471ae03d87b90a797e6f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 1 Oct 2025 10:13:40 -0400 Subject: [PATCH 020/141] [JAX] Fix `rng_state` shape in fused attention (#2217) fix rng_state shape Signed-off-by: Phuong Nguyen Co-authored-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/attention.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 625f42049..db2537c38 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1820,7 +1820,7 @@ def ring_attn_fwd_impl( # RNG shape should be the shared shape. This is unused for ring attention as we do not # support dropout currently. - rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:]) rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): @@ -2306,7 +2306,7 @@ def fwd_impl( # RNG shape should be the shared shape. This is unused for ring attention as we do not # support dropout currently. - rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:]) rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): From ac886c3594a80e05ad6682b13e3099e3bdc8248d Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 1 Oct 2025 19:32:16 +0200 Subject: [PATCH 021/141] [PyTorch] Fix QuantizedTensorBase -> QuantizedTensorStorage (#2226) Fix QuantizedTensorBase -> QuantizedTensorStorage Signed-off-by: Evgeny --- .../attention/dot_product_attention/backends.py | 6 +++--- .../dot_product_attention/context_parallel.py | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index f72c1eb9e..3a1375838 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -25,7 +25,7 @@ Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.quantized_tensor import ( - QuantizedTensorBase, + QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) @@ -1312,7 +1312,7 @@ def backward(ctx, d_out): # d_out is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorBase): + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): d_out = ctx.dO_quantizer(d_out) if not ctx.use_FAv2_bwd: d_out._data = d_out._data.contiguous() @@ -1479,7 +1479,7 @@ def backward(ctx, d_out): ctx.dP_quantizer, ) else: - if isinstance(d_out, QuantizedTensorBase): + if isinstance(d_out, QuantizedTensorStorage): d_out = d_out.dequantize(dtype=ctx.nominal_dtype) dqkv_te_dtype = TE_DType[d_out.dtype] # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 539caffbb..d0ddae25e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -21,7 +21,7 @@ ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorBase +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( dist_group_type, @@ -1823,7 +1823,7 @@ def backward(ctx, dout): # dout is expected to be in FP8 if is_output_fp8=True, # but in the case it's not, convert it to FP8 before any operation - if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorBase): + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): dout = ctx.dO_quantizer(dout) if ctx.use_fused_attention: dout._data = dout._data.contiguous() @@ -1997,7 +1997,7 @@ def backward(ctx, dout): dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: - if isinstance(dout, QuantizedTensorBase): + if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) dq_buffer = torch.empty_like(q) p2p_comm_buffers = [ @@ -3396,7 +3396,7 @@ def backward(ctx, dout): if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if not isinstance(dout, QuantizedTensorBase): + if not isinstance(dout, QuantizedTensorStorage): dout = ctx.dO_quantizer(dout) dout_fp8 = dout dqkv_te_dtype = dout._fp8_dtype @@ -3409,7 +3409,7 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" else: - if isinstance(dout, QuantizedTensorBase): + if isinstance(dout, QuantizedTensorStorage): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} From f0a9404881777ba0496e56e62d682ebb3896e91c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 1 Oct 2025 10:33:05 -0700 Subject: [PATCH 022/141] Fix hang during debug build (#2221) Disable debug build for cutlass GEMM Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index a4915080e..e0fe3c04a 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -164,6 +164,9 @@ else() message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") endif() +# Disable debug build for cutlass due to hang. +set_source_files_properties("gemm/cutlass_grouped_gemm.cu" PROPERTIES COMPILE_FLAGS "-g0") + # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas From 90449f796718022fd34ae518c7f4a37df0fc76f2 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 1 Oct 2025 14:09:38 -0700 Subject: [PATCH 023/141] Convert `NVFP4BlockScaling` to dataclass (#2227) Fix passing args to nvfp4 recipe Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 324b5d50c..1a9b02987 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -382,6 +382,7 @@ def __repr__(self) -> str: ) +@dataclass() class NVFP4BlockScaling(Recipe): """ Use the NVFP4 scaling strategy. From aee5a82108bc3053ef01fa1bd2459fe7c0a154f5 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 1 Oct 2025 14:12:15 -0700 Subject: [PATCH 024/141] Fix the cuBLAS workspace alignment (#2223) * Fix the cublas workspace alignment Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemek Tredak Signed-off-by: Przemyslaw Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- transformer_engine/common/gemm/cublaslt_gemm.cu | 16 ++++++++++++---- transformer_engine/pytorch/module/base.py | 4 ++-- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index ab80fe769..a4810881c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -679,6 +679,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #endif } + // align the workspace to 256 B + const int required_alignment = 256; + const auto original_workspace_alignment = _getAlignment(reinterpret_cast(workspace)); + uint8_t *aligned_workspace_ptr = + reinterpret_cast(workspace) + required_alignment - original_workspace_alignment; + workspaceSize = workspaceSize - required_alignment + original_workspace_alignment; + const auto new_workspace_alignment = + _getAlignment(reinterpret_cast(aligned_workspace_ptr)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); @@ -686,7 +694,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); - const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -695,8 +702,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); - NVTE_CHECK(workspace_alignment % 256 == 0, - "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment); + NVTE_CHECK(new_workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 bytes, got ", + new_workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, @@ -714,7 +722,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, C, /* C */ Cdesc, D, /* D */ Ddesc, &heuristicResult.algo, /* algo */ - workspace, /* workspace */ + aligned_workspace_ptr, /* workspace */ workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d60ff8059..3ae389568 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -78,8 +78,8 @@ class UserBufferQuantizationMode(Enum): def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - # 32 MiB for NVFP4 GEMM, plus 256 B for misc scales - return 32 * 1024 * 1024 + 256 + # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales + return 32 * 1024 * 1024 + 1024 return 4_194_304 From c1003181dbd5123a3e349266e8dc118f89d78485 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:42:36 -0700 Subject: [PATCH 025/141] [PyTorch] Set usages for linear op quantizers before forward (#2222) * Make sure to set usages for linear op quantizers before forward Signed-off-by: Tim Moon * Avoid unsupported case for fused dbias+quantize kernel Hopper does not support dbias + FP8 cast without FP8 transpose. Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon --- tests/pytorch/distributed/test_fusible_ops.py | 215 +++++++++++++++++- .../pytorch/csrc/extensions/bias.cpp | 23 +- .../pytorch/ops/basic/basic_linear.py | 76 ++++--- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 4 + transformer_engine/pytorch/ops/op.py | 11 + 8 files changed, 296 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 11fe4333b..af0f0e931 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -635,6 +635,204 @@ def _test_linear( torch.testing.assert_close(db_test, db_ref, **tols) +def _test_mlp( + *, + bias: bool = True, + hidden_size: int = 32, + local_batch_size: int = 32, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str] = None, + quantized_weight: bool = False, + sequence_parallel: bool = False, +) -> None: + """2-layer MLP + + MLP includes GELU activation in order to test op fusions. Model + performs warmup steps in order to test inter-step logic. + + """ + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + mlp_size = hidden_size * world_size + batch_size = local_batch_size + if sequence_parallel: + batch_size *= world_size + in_shape = (batch_size, hidden_size) + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w1_ref, w1_test = make_reference_and_test_tensors( + (mlp_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b1_ref, b1_test = None, None + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, mlp_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = None, None + if bias: + b1_ref, b1_test = make_reference_and_test_tensors( + (mlp_size,), + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = make_reference_and_test_tensors( + (world_size, hidden_size), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w1_ref) + if bias: + y_ref += b1_ref + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w2_ref) + if bias: + y_ref += b2_ref.sum(dim=0) + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + local_mlp_size = mlp_size // world_size + local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size) + dx_ref = x_ref.grad + dw1_ref = w1_ref.grad[local_mlp_slice, :] + w1_ref = w1_ref[local_mlp_slice, :] + w1_test = w1_test[local_mlp_slice, :] + dw2_ref = w2_ref.grad[:, local_mlp_slice] + w2_ref = w2_ref[:, local_mlp_slice] + w2_test = w2_test[:, local_mlp_slice] + if bias: + db1_ref = b1_ref.grad[local_mlp_slice] + b1_ref = b1_ref[local_mlp_slice] + b1_test = b1_test[local_mlp_slice] + db2_ref = b2_ref.grad[rank, :] + b2_ref = b2_ref[rank, :] + b2_test = b2_test[rank, :] + else: + db1_ref = None + db2_ref = None + if sequence_parallel: + local_batch_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + x_ref = x_ref[local_batch_slice, ...] + dx_ref = dx_ref[local_batch_slice, ...] + x_test = x_test[local_batch_slice, ...].clone() + y_ref = y_ref[local_batch_slice, ...] + dy_ref = dy_ref[local_batch_slice, ...] + dy_test = dy_test[local_batch_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model = te_ops.Sequential( + te_ops.GELU(), + te_ops.Linear( + hidden_size, + mlp_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="column", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + te_ops.Linear( + mlp_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="row", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + ) + with torch.no_grad(): + model[1].weight.copy_(w1_test) + model[3].weight.copy_(w2_test) + if bias: + model[1].bias.copy_(b1_test) + model[3].bias.copy_(b2_test) + del w1_test, w2_test, b1_test, b2_test + + # Warmup steps + for _ in range(3): + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + x_test.grad = None + model[1].weight.grad = None + model[3].weight.grad = None + if bias: + model[1].bias.grad = None + model[3].bias.grad = None + + # Forward and backward step + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw1_test, dw1_ref, **tols) + torch.testing.assert_close(dw2_test, dw2_ref, **tols) + if bias: + db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") + db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db1_test, db1_ref, **tols) + torch.testing.assert_close(db2_test, db2_ref, **tols) + + def _test_fp8_scale_update( *, amax_history_len: int = 31, @@ -801,16 +999,31 @@ def run_parallel_tests() -> None: for config in itertools.product( quantization_list, ("column", "row"), + (False, True), ): if rank == 0: print(f"Running _test_linear with {config=}") - quantization, tensor_parallel_mode = config + quantization, tensor_parallel_mode, sequence_parallel = config dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, + sequence_parallel=sequence_parallel, + ) + + # MLP + for config in itertools.product(quantization_list, (False, True)): + if rank == 0: + print(f"Running _test_mlp with {config=}") + quantization, sequence_parallel = config + dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 + _test_mlp( + bias=True, # bias=False is tested in _test_basic_linear + dtype=dtype, + quantization=quantization, + sequence_parallel=sequence_parallel, ) # FP8 scale update diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0531596dd..b0435d272 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -54,10 +54,25 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; } - // Unfused impl if quantizer is not supported - const bool with_fused_dbias_quantize_kernel = - detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); - if (!with_fused_dbias_quantize_kernel) { + // Check if fused kernel is supported + bool with_fused_kernel = false; + if (detail::IsFloat8Quantizers(quantizer.ptr())) { + auto prop = at::cuda::getCurrentDeviceProperties(); + const size_t sm_arch = 10 * prop->major + prop->minor; + if (sm_arch >= 100) { + // Fused kernel for dbias + FP8 cast on SM arch 10.0+ + with_fused_kernel = true; + } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) { + // Fused kernel for dbias + FP8 cast + FP8 transpose + with_fused_kernel = true; + } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Fused kernel for dbias + MXFP8 quantize + with_fused_kernel = true; + } + + // Apply unfused impl if fused kernel is not supported + if (!with_fused_kernel) { at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 844e49ff0..cb2119296 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -322,6 +322,20 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. + weight_requires_grad = requires_grad and self.weight.requires_grad + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -352,6 +366,35 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: and not getattr(self, "_with_quantized_weight", False) ) + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + if recipe.nvfp4(): + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin @@ -731,7 +774,7 @@ def _functional_backward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(columnwise=True) + input_quantizer.set_usage(rowwise=False, columnwise=True) if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, @@ -912,42 +955,13 @@ def op_forward( input_requires_grad = ctx.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - # Configure quantizers - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. - input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - - # Recipe-specific configuration - recipe = FP8GlobalStateManager.get_fp8_recipe() - if recipe.float8_current_scaling(): - input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group - if recipe.nvfp4(): - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Get autocast dtype if needed if torch.is_autocast_enabled(): diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 02bcfee0a..ab271e17b 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -85,7 +85,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 15cc081c1..4831ae407 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -79,7 +79,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 21190d4fc..72e17f64e 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -58,7 +58,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index ccd7ee52b..6f80a7a1f 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -472,6 +472,10 @@ def __call__( # Attempt to fuse operations if neccesary self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) + # Initialization before forward + for idx, op in enumerate(self._basic_ops): + op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) + # Fuser forward pass if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 903bc49d5..103ebf241 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -65,6 +65,13 @@ def is_fused_op(self) -> bool: def pre_first_fuser_forward(self) -> None: """Preprocessing before first fuser forward pass""" + def pre_fuser_forward( + self, + *, + requires_grad: bool, # pylint: disable=unused-argument + ) -> None: + """Preprocessing before fuser forward pass""" + def get_input_quantizer(self) -> Optional[Quantizer]: """Get builder class for quantized input tensor""" @@ -710,6 +717,10 @@ def pre_first_fuser_forward(self) -> None: for op in self.basic_ops: op.pre_first_fuser_forward() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + for op in self.basic_ops: + op.pre_fuser_forward(requires_grad=requires_grad) + def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin From f936c2ac82f348deba74180eea1732a55e118cc6 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 2 Oct 2025 06:31:55 -0700 Subject: [PATCH 026/141] [JAX] Fix code block in fp8_autocast docstring (#2228) Fix code block in fp8_autocast docstring Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- transformer_engine/jax/quantize/helper.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 3d460e81a..67f0a68c6 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -404,20 +404,20 @@ def fp8_autocast( This context manager enables FP8 quantization for the duration of its context. .. code-block:: python - mesh_shape = (4, 2) - dp_mesh_axis_name = 'data_parallel' - tp_mesh_axis_name = 'tensor_parallel' - devices = np.asarray(jax.devices()).reshape(*mesh_shape) + mesh_shape = (4, 2) + dp_mesh_axis_name = 'data_parallel' + tp_mesh_axis_name = 'tensor_parallel' + devices = np.asarray(jax.devices()).reshape(*mesh_shape) - with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): - mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) + with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): + mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) - with fp8_autocast(enabled=True, mesh_resource=mesh_resource): - rules = extend_logical_axis_rules(tuple()) - transformer = TransformerLayer() + with fp8_autocast(enabled=True, mesh_resource=mesh_resource): + rules = extend_logical_axis_rules(tuple()) + transformer = TransformerLayer() - with partitioning.axis_rules(rules): - pjit(transformer.init, ...)(...) + with partitioning.axis_rules(rules): + pjit(transformer.init, ...)(...) .. note:: We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`, From be7f43f10ce34ce3f878a63933a6dd45eb10bafc Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 2 Oct 2025 06:32:24 -0700 Subject: [PATCH 027/141] [JAX] Fix shard map issue when `get_all_mesh_axes()` is used (#2229) Fix shard map issue Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- transformer_engine/jax/sharding.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 7a8261269..d3a7952d3 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -131,7 +131,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): # We want to exclude the axes that already used by shard_map and shard_map # only sets those in the abstract_mesh, not the physical one manual_axis_names = get_abstract_mesh().manual_axes - cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec) + + # Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too + def filter_manual_axes(name_or_tuple): + if isinstance(name_or_tuple, tuple): + out = tuple(n for n in name_or_tuple if n not in manual_axis_names) + if len(out) == 0: + return None + return out + if name_or_tuple in manual_axis_names: + return None + return name_or_tuple + + cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec) + + if cleaned_axis_names == (None,) * len(cleaned_axis_names): + return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) return jax.lax.with_sharding_constraint(x, cleaned_pspec) From e30c36a30883c49b820096a0bb856c7ea71bebd5 Mon Sep 17 00:00:00 2001 From: hx Date: Fri, 3 Oct 2025 05:04:28 +0800 Subject: [PATCH 028/141] [PyTorch] fix int32 overflow in permute kernels (#2196) * fix overflow of int32 in permute kernels Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../pytorch/triton/permutation.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index ceb88108f..6292acb69 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -324,7 +324,8 @@ def _permute_kernel( pid_h = tl.program_id(1) cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = cur_off < hidden_size - input_off = pid_t * stride_input_token + cur_off * stride_input_hidden + src_row = pid_t.to(tl.int64) + input_off = src_row * stride_input_token + cur_off * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) if PERMUTE_SCALE: mask_scale = cur_off < scale_hidden_dim @@ -338,7 +339,7 @@ def _permute_kernel( for idx in tl.range(n_routed): dst_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ) + ).to(tl.int64) output_off = dst_row * stride_output_token + cur_off * stride_output_hidden if PERMUTE_SCALE: permuted_scale_off = ( @@ -519,7 +520,7 @@ def _unpermute_kernel( for idx in tl.range(n_routed): src_row = tl.load( row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ) + ).to(tl.int64) input_off = src_row * stride_input_token + current_offset * stride_input_hidden inp = tl.load(input_ptr + input_off, mask=mask) inp = inp.to(compute_type) @@ -550,7 +551,8 @@ def _unpermute_kernel( prob = tl.load(permuted_probs_ptr + permuted_prob_off) tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) accumulator = accumulator.to(data_type) - output_off = pid_t * stride_output_token + current_offset * stride_output_hidden + dst_row = pid_t.to(tl.int64) + output_off = dst_row * stride_output_token + current_offset * stride_output_hidden tl.store(output_ptr + output_off, accumulator, mask=mask) @@ -681,7 +683,7 @@ def _unpermute_bwd_with_merging_probs_kernel( for idx in tl.range(n_routed): dst_row = tl.load( row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert - ) + ).to(tl.int64) expert_idx = tl.load( row_id_map_ptr + pid * stride_row_id_map_token @@ -692,8 +694,10 @@ def _unpermute_bwd_with_merging_probs_kernel( while current_start < hidden_size: current_offset = current_start + tl.arange(0, BLOCK_SIZE) mask = current_offset < hidden_size + src_row = pid.to(tl.int64) input_off = ( - pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden + src_row * stride_fwd_output_grad_token + + current_offset * stride_fwd_output_grad_hidden ) inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) inp = inp.to(compute_type) @@ -902,11 +906,11 @@ def _sort_chunks_by_map_kernel( pid_t = tl.program_id(0) pid_h = tl.program_id(1) if FORWARD: - src_row = pid_t - dst_row = tl.load(row_id_map_ptr + pid_t) + src_row = pid_t.to(tl.int64) + dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) else: - src_row = tl.load(row_id_map_ptr + pid_t) - dst_row = pid_t + src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) + dst_row = pid_t.to(tl.int64) current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) mask = current_offset < hidden_size input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden From b840898b75162bce68fbc3c9c8234b6f23dcdbff Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Fri, 3 Oct 2025 09:39:46 -0700 Subject: [PATCH 029/141] [JAX] Clamped Swiglu Integration (#2194) Signed-off-by: Varun Thumbe *Jax integration for clamped swiglu. This is the continuation of PR which added Clamped Swiglu(used in GPT OSS) support in TE along with Pytorch integration. This PR hooks up the clamped swiglu and dswiglu's nvte APIs to TE Jax. --- tests/jax/test_custom_call_compute.py | 89 ++++++--- .../include/transformer_engine/activation.h | 1 + .../common/util/cast_gated_kernels.cuh | 14 +- transformer_engine/jax/activation.py | 26 ++- .../jax/cpp_extensions/activation.py | 176 +++++++++++++++--- transformer_engine/jax/csrc/extensions.h | 17 ++ .../jax/csrc/extensions/activation.cpp | 52 ++++-- .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/jax/flax/module.py | 13 +- transformer_engine/jax/flax/transformer.py | 5 + transformer_engine/jax/layernorm_mlp.py | 18 +- 11 files changed, 324 insertions(+), 88 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7f15eec89..7a4fa268a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -170,6 +170,7 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { @@ -182,17 +183,21 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer) + def primitive_func(self, inputs, activation_type, quantizer, act_params): + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -209,12 +214,20 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -234,7 +247,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -242,9 +256,21 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -273,10 +299,18 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -296,10 +330,18 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = self.ref_act(x, activation_type, act_params) assert_dequantized_scaled_tensor(output, ref_out) @@ -734,6 +776,7 @@ def test_quantize_dbias( def _test_quantize_dact_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) @@ -785,7 +828,7 @@ def _test_quantize_dact_dbias( (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. or ( - activation_type == ("squared_relu",) + activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")} and in_dtype == jnp.bfloat16 and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING ) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index e50d71040..4e4808858 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU }; /*! \brief Computes the GeLU activation of the input. diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index ca37a2831..93086bd82 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -924,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1006,7 +1006,7 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { checkCuDriverContext(stream); @@ -1138,7 +1138,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: @@ -1155,7 +1154,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); break; @@ -1180,7 +1178,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } template -void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { +void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(input, "gated_act_input"); CheckOutputTensor(*output, "gated_act_output"); NVTE_CHECK(input.flat_last_dim() % 2 == 0, @@ -1213,7 +1211,7 @@ void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t st template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p, +void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { CheckInputTensor(grad, "dgated_act_grad"); CheckInputTensor(input, "dgated_act_input"); @@ -1252,7 +1250,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamO template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p, +void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, cudaStream_t stream) { constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); @@ -1318,7 +1316,7 @@ namespace detail { template void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - ParamOP &p, cudaStream_t stream) { + ParamOP p, cudaStream_t stream) { using namespace gated_kernels; Tensor grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43..daa3679c4 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp - from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor @@ -22,6 +21,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[tex.activation.ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. @@ -32,17 +32,19 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. ctx: Context from forward pass g: Gradient from upstream @@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index a8c14a608..925c1d01a 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -5,6 +5,7 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial +from dataclasses import dataclass import jax import jax.numpy as jnp @@ -12,9 +13,9 @@ from jax.experimental.custom_partitioning import SdyShardingRule from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -51,17 +52,87 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } -def _convert_to_activation_function(fn_or_string): +@dataclass(frozen=True) +class ClampedSwigluParams: + """Parameters for the Clamped SwiGLU activation function + used in GPT OSS.""" + + limit: float = 7.0 + alpha: float = 1.702 + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work. + + Returns: + int: Hash value of the dataclass instance. + """ + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + + +@dataclass(frozen=True) +class ActivationParams: + """Parameters for various activation functions. + Currently only Clamped SwiGLU activation has parameters. + """ + + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + CLAMPED_ACTIVATION_TYPES = { + ("clamped_silu", "clamped_linear"), + "clamped_silu", + "clamped_linear", + } + if activation_type in CLAMPED_ACTIVATION_TYPES: + return ActivationParams(ClampedSwigluParams(**kwargs)) + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work""" + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + # This function is used for ClampedSwiGLU + # used in GPT OSS where the gates are not only clamped + # but also shifted by +1 + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha + return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -84,7 +155,8 @@ class ActLuPrimitive(BasePrimitive): 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params inner_primitive = None outer_primitive = None @@ -100,11 +172,12 @@ def abstract( is_2x, scale_dtype, is_outer, + act_params, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 @@ -150,6 +223,7 @@ def lowering( is_2x, scale_dtype, is_outer, + act_params, ): """ te_gated_act_lu_p lowering rules @@ -158,9 +232,14 @@ def lowering( x_aval, scale_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + ctx, + x, + scale, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + is_2x=is_2x, + act_params=act_params.to_ffi_lowering_dict(), ) return out @@ -175,6 +254,7 @@ def impl( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe implementation @@ -193,6 +273,7 @@ def impl( is_2x=is_2x, scale_dtype=scale_dtype, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -221,6 +302,7 @@ def batcher( is_2x, scale_dtype, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -242,6 +324,7 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + act_params=act_params, ), out_bdims, ) @@ -255,6 +338,7 @@ def infer_sharding_from_operands( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -266,6 +350,7 @@ def infer_sharding_from_operands( scale_dtype, act_len, is_outer, + act_params, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -318,6 +403,7 @@ def partition( is_2x, scale_dtype, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -378,6 +464,7 @@ def sharded_impl(x, scale): is_2x=is_2x, scale_dtype=scale_dtype, is_outer=True, + act_params=act_params, ) ) @@ -405,11 +492,12 @@ def shardy_sharding_rule( is_2x, scale_dtype, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types + del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params prefix = "ActLu_" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] @@ -455,8 +543,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params + impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) inner_primitive = None outer_primitive = None @@ -474,11 +562,12 @@ def abstract( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p abstract """ - del act_enum + del act_enum, act_params dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype @@ -575,6 +664,7 @@ def lowering( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p lowering rules @@ -593,6 +683,7 @@ def lowering( is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), ) @staticmethod @@ -608,6 +699,7 @@ def impl( act_enum, act_len, is_outer, + act_params, ): """ te_dact_dbias_quantize_p impl @@ -627,6 +719,7 @@ def impl( act_enum=act_enum, act_len=act_len, is_outer=False, + act_params=act_params, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -655,6 +748,7 @@ def batcher( act_enum, act_len, is_outer, + act_params, ): """ to describe batch rules for vmap @@ -685,6 +779,7 @@ def batcher( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, ), out_bdims, ) @@ -699,11 +794,12 @@ def infer_sharding_from_operands( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum + del out_dtype, result_infos, act_enum, act_params del scale_dtype, act_len, is_outer x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -774,6 +870,7 @@ def partition( act_enum, act_len, is_outer, + act_params, mesh, arg_infos, result_infos, @@ -854,6 +951,7 @@ def sharded_impl(dz, x, scale): act_enum=act_enum, act_len=act_len, is_outer=True, + act_params=act_params, ) ) if is_dbias: @@ -880,11 +978,13 @@ def shardy_sharding_rule( act_enum, act_len, is_outer, + act_params, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types + + del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 @@ -923,20 +1023,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -951,10 +1053,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -962,7 +1066,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -985,6 +1090,7 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, amax_scope: AmaxScope = AmaxScope.LOCAL, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -1008,24 +1114,22 @@ def act_lu( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) - + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) - if quantizer is None: out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( x, @@ -1037,6 +1141,7 @@ def act_lu( is_2x=False, scale_dtype=jnp.float32, is_outer=True, + act_params=act_params, ) out = out.reshape(output_shape) out = NoScaleTensor( @@ -1051,6 +1156,7 @@ def act_lu( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, _ = _quantize_dbias_impl( out, @@ -1060,7 +1166,6 @@ def act_lu( amax_scope=amax_scope, ) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1080,6 +1185,7 @@ def act_lu( is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), is_outer=True, + act_params=act_params, ) quantizer.update(updated_amax) @@ -1102,6 +1208,7 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1118,7 +1225,7 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. """ - + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -1131,8 +1238,7 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) - + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1148,6 +1254,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) output = output.astype(x.dtype) dbias = None @@ -1163,7 +1270,11 @@ def quantize_dact_dbias( # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, ) return _quantize_dbias_impl( out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1180,6 +1291,7 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params, ) if war_output is not None: return war_output @@ -1191,6 +1303,7 @@ def quantize_dact_dbias( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 @@ -1203,7 +1316,10 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, ) out, dbias = _quantize_dbias_impl( dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 @@ -1229,6 +1345,7 @@ def quantize_dact_dbias( act_enum=act_type_id, act_len=act_len, is_outer=True, + act_params=act_params, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1257,6 +1374,7 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1270,11 +1388,13 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. """ + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params, ) return output diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 2ab95002f..bbfc62120 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -36,6 +36,15 @@ namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -137,6 +146,14 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember("clamped_swiglu")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index b2b3db52c..0ecf79150 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + bool is_2x_int, ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -125,6 +128,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -145,17 +152,19 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("is_2x") + .Attr("act_params"), FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int) { + JAXX_Scaling_Mode scaling_mode, bool is_2x_int, + ActivationConfig act_params) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, - act_enum, scaling_mode, is_2x_int); + act_enum, scaling_mode, is_2x_int, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, @@ -170,7 +179,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ret() // amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x")); + .Attr("is_2x") + .Attr("act_params")); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -240,7 +250,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + int64_t act_enum, bool is_2x, bool is_dbias, + ActivationConfig act_params) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -407,6 +421,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -432,21 +450,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias"), + .Attr("is_dbias") + .Attr("act_params"), FFI_CudaGraph_Traits); -Error_Type DActLuDBiasQuantizeInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, - Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t act_enum, - bool is_2x, bool is_dbias) { +Error_Type DActLuDBiasQuantizeInitializeFFI( + cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, + Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, + bool is_dbias, ActivationConfig act_params) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias); + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -466,7 +483,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") - .Attr("is_dbias")); + .Attr("is_dbias") + .Attr("act_params")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 36dd8205b..23d46b338 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -143,6 +143,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54ef..f02876d8f 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -898,6 +898,10 @@ class LayerNormMLP(TransformerEngineBase): activations: Sequence[Union[str, Callable]], default = ('relu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters needed(if any) by the activation functions specified in :attr:`activations`. + At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS + need additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. intermediate_dropout_rate: float, default = 0.1 @@ -956,6 +960,7 @@ class LayerNormMLP(TransformerEngineBase): bias_axes_2: Tuple[str, ...] = ("embed",) return_layernorm_output: bool = True activations: Sequence[Union[str, Callable]] = ("relu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" intermediate_dropout_rate: float = 0.1 intermediate_hidden_dropout_dims: Sequence[int] = () @@ -1023,6 +1028,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), ("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1037,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1150,6 +1158,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1287,4 +1296,4 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype - return out, ln_output # Output, layner_norm_output + return out, ln_output # Output, layer_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index ad66684f2..868bcfa05 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1632,6 +1632,9 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mlp_activations: Sequence[str], default = ('relu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. @@ -1752,6 +1755,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None mlp_activations: Sequence[str] = ("relu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False @@ -2046,6 +2050,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index cf77f8e0a..77daa4672 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -50,6 +50,7 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, collective_op_sets: Tuple[tex.CollectiveOpSet] = ( tex.noop_collective_op_set, tex.noop_collective_op_set, @@ -138,13 +139,14 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -165,6 +167,7 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): @@ -220,6 +223,7 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ) @@ -246,6 +250,7 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, quantizer_sets, ): @@ -335,6 +340,11 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -402,6 +412,7 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, collective_op_sets, ctx, grad, @@ -497,6 +508,11 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim From dfe5b7dfc2288afc5d2f247709b1e0328af331e4 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Fri, 3 Oct 2025 20:09:41 +0200 Subject: [PATCH 030/141] [Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell (#2157) * Update to_string(NVTEScalingMode) to include block scaling Signed-off-by: Jan Bielak * Add `nvte_swizzle_block_scaling_to_mxfp8_scaling_factors` Signed-off-by: Jan Bielak * Convert FP8 block scaling tensors to MXFP8 tensors on Blackwell and newer in GEMM Signed-off-by: Jan Bielak * Allow Blackwell and newer in Deepseek recipe compatbility check Signed-off-by: Jan Bielak * Allow data_rows % 4 != 0 in 1d kernel Signed-off-by: Jan Bielak * Load scaling factors in unswizzled order in 1d kernel Signed-off-by: Jan Bielak * Enforce use of power of two scaling Signed-off-by: Jan Bielak * Skip the FP8 block scaling exact GEMM test on Blackwell Signed-off-by: Jan Bielak * Skip further tests with pow_2_scales=False Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Initial implementation of tensor conversion for grouped gemm Signed-off-by: Jan Bielak * Skip non power of two scaling cpp unit tests Signed-off-by: Jan Bielak * Fix handling of all gather Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jan Bielak * Use compute capability 10.0 for logic with Blackwell Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Jan Bielak Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../cpp/operator/test_cast_float8blockwise.cu | 12 + .../test_float8_blockwise_gemm_exact.py | 4 +- .../test_float8_blockwise_scaling_exact.py | 14 + transformer_engine/common/CMakeLists.txt | 1 + .../include/transformer_engine/swizzle.h | 20 ++ .../common/swizzle/swizzle_block_scaling.cu | 321 ++++++++++++++++++ .../common/transformer_engine.cpp | 4 + .../quantize_transpose_square_blockwise.cu | 6 + .../quantize_transpose_vector_blockwise.cu | 6 + .../pytorch/csrc/extensions/gemm.cpp | 99 ++++-- transformer_engine/pytorch/csrc/util.cpp | 70 ++++ transformer_engine/pytorch/csrc/util.h | 12 + transformer_engine/pytorch/distributed.py | 8 +- transformer_engine/pytorch/fp8.py | 11 +- 14 files changed, 553 insertions(+), 35 deletions(-) create mode 100644 transformer_engine/common/swizzle/swizzle_block_scaling.cu diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index e5faa688c..fe4ae2d26 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 2u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. @@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 1u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index ec23cfe8c..bdc73519b 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -19,7 +20,8 @@ def fp8_blockwise_gemm_supported() -> bool: supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() - return supported + emulated = get_device_compute_capability() >= (10, 0) + return supported and not emulated def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 858ce73b6..51e0d1ec9 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -12,6 +12,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8BlockScaling +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( Float8BlockQuantizer, @@ -32,6 +33,7 @@ if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() +recipe_emulated = get_device_compute_capability() >= (10, 0) class GetRecipes: @@ -218,6 +220,12 @@ def check_quantization_block_tiling_versus_reference( pow_2_scales: bool, tile_size: Tuple[int, int], ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): block_scaling_dim = 1 @@ -409,6 +417,12 @@ def test_quantization_block_tiling_extrema_versus_reference( tile_size: Tuple[int, int], extrema_high: bool, ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + # This test runs a single tile through a quantizer as a way to test # branch coverage of scale computation. te_dtype = TE_DType[quant_dtype] diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e0fe3c04a..92b57897d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -127,6 +127,7 @@ list(APPEND transformer_engine_SOURCES util/multi_stream.cpp util/rtc.cpp swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 079feb4a7..624e71d1e 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM + * + * \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv. + * \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it + * not natively supported by cublasLt on architectures other than Hopper. + + * Requirements: + * - input is an FP8 block scaling tensor + * - input has rowwise usage + * - input.scale_inv is in GEMM_READY format + * - output is an MXFP8 tensor + * - output has rowwise usage + * - output.scale_inv has appropriate shape + * */ +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu new file mode 100644 index 000000000..4be85474a --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -0,0 +1,321 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace { +constexpr uint32_t WARP_SIZE = 32; +} // namespace +namespace swizzle_kernel_1d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where +// each thread stores a single row (of four bytes). +// Example: +// lane0.row = 0x00010203 +// lane1.row = 0x04050607 +// lane2.row = 0x08090a0b +// lane3.row = 0x0c0d0e0f +// Becomes: +// lane0.row = 0x0004080c +// lane1.row = 0x0105090d +// lane2.row = 0x02060a0e +// lane3.row = 0x03070b0f +uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row, + const uint32_t lane, + const uint32_t active_mask) { + using cu = const uint32_t; + + // Threads operate in groups of 4, and each thread stores 4 bytes at a time. + // The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes + // until we have transposed the 4x4 matrix. + cu m_0123_4567_89ab_cdef = row; + cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4); + cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240); + cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715); + cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea; + cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4); + cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410); + cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276); + cu m_048c_159d_26ae_37bf = (lane & 2) ? m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15; + + return m_048c_159d_26ae_37bf; +} + +// Expands a uint32_t to a uint4 by duplicating each byte four times. +// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404} +uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) { + return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222), + __byte_perm(x, 0, 0x3333)}; +} + +// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data +// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors. +struct no_oob_tag_t {}; +constexpr no_oob_tag_t NO_OOB_TAG; + +template +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride, + OOBT first_oob) { + // resolve kernel variant + constexpr bool no_oob = std::is_same_v; + static_assert(no_oob || std::is_same_v); + + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_x; + const uint32_t in_tile_x = out_tile_y; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling factors for this lane's initial four 1x128 tiles + uint4 sf; + if constexpr (no_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + if ((out_tile_y < tiles_y - 1) || lane < first_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + sf = uint4{0, 0, 0, 0}; + } + } + + // pack the exponent bits of the scaling factors + uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); + + // partially swizzle the scaling factors + constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches + const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); + packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx); + + // transpose 4x4 matrices of scaling factors + packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK); + + // broadcast the scaling factors for sixteen 1x32 tiles + sf = broadcast_uint32_t_to_uint4(packed_exponents); + + // store them cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ", + alignof(uint4), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales + // and a 128x4 tile in the output scales. The input scales are in transposed order. + const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + const uint32_t first_oob = (input_scale_inv_cols % 128) / 4; + + if (first_oob == 0) { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG); + } else { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob); + } +} +} // namespace swizzle_kernel_1d +namespace swizzle_kernel_2d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_y; + const uint32_t in_tile_x = out_tile_x; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = sizeof(float); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling factor for this warp's 128x128 tile + uint32_t sf = *reinterpret_cast(warp_src); + + // broadcast it to four scaling factors for 1x32 tiles + sf = (sf << 1) | (sf >> 7); + sf = sf | (sf >> 16); + + // broadcast it to sixteen scaling factors for 1x32 tiles + const uint4 sf4{sf, sf, sf, sf}; + + // store it cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf4; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ", + alignof(float), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales + // and a 128x4 tile in the output scales. + const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); +} +} // namespace swizzle_kernel_2d + +void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output, + cudaStream_t stream) { + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "block_scaling_scaling_factor_input"); + CheckInputTensor(*output, "mxfp8_scaling_factor_output"); + + const NVTEScalingMode scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING, + "Output tensor must be an mxfp8 tensor"); + + NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 || + input->data.dtype == transformer_engine::DType::kFloat8E5M2, + "Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8"); + NVTE_CHECK(output->data.dtype == input->data.dtype, + "Output data must have the same dtype as input data"); + NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors"); + NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0, + "Output must have E8M0 scaling factors"); + + NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data"); + NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input"); + NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors"); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors"); + + NVTE_CHECK(input->data.shape.size() == 2, "Input data must be a matrix"); + NVTE_CHECK(output->data.shape == input->data.shape, + "Output data must have the same shape as input data"); + NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling factors must be a matrix"); + NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix"); + + const size_t data_rows = input->data.shape[0]; + const size_t data_cols = input->data.shape[1]; + const size_t input_scale_inv_rows = input->scale_inv.shape[0]; + const size_t input_scale_inv_cols = input->scale_inv.shape[1]; + const size_t output_scale_inv_rows = output->scale_inv.shape[0]; + const size_t output_scale_inv_cols = output->scale_inv.shape[1]; + + NVTE_CHECK(output_scale_inv_rows == DIVUP(data_rows, 128) * 128, + "Expected the output scaling factor matrix to have ", + DIVUP(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows, + " rows instead."); + NVTE_CHECK(output_scale_inv_cols == DIVUP(data_cols, 128) * 4, + "Expected the output scaling factor matrix to have ", + DIVUP(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols, + " columns instead."); + + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_cols, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_cols, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_rows, 4) * 4, + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 4) * 4, + " columns, but it has ", input_scale_inv_cols, " columns instead."); + + swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } else { // scaling_mode == NVTE_BLOCK_SCALING_2D + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_rows, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_cols, 512) * 4, + "Expected the input scaling factor matrix to have ", + DIVUP(data_cols, 512) * 4, " columns, but it has ", input_scale_inv_cols, + " columns instead."); + + swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } +} + +} // namespace transformer_engine + +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors); + using namespace transformer_engine; + swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input), + convertNVTETensorCheck(output), stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index f49fe239a..35e8b683a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -64,6 +64,10 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; + case NVTE_BLOCK_SCALING_1D: + return "NVTE_BLOCK_SCALING_1D"; + case NVTE_BLOCK_SCALING_2D: + return "NVTE_BLOCK_SCALING_2D"; case NVTE_NVFP4_1D_SCALING: return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index c3f085b87..661cf339a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,6 +14,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -485,6 +486,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index d38bf7996..fcf7a151c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,6 +17,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -529,6 +530,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; size_t num_rows = 1; diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 136459751..15404ad9a 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -104,6 +104,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const bool low_precision = detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); + const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; // Check tensor dimensions const auto& A_shape = A_tensor.shape(); @@ -235,6 +239,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { + // Convert tensors to mxfp8 and swizzle their scaling factors + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb))); + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } + if (comm_overlap) { // Prepare extra output tensor TensorWrapper extra_output_tensor; @@ -379,15 +396,6 @@ std::optional> te_general_grouped_gemm( std::vector bias, DType bias_type, bool single_output, std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { - std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, - te_pre_gelu_out_vector, te_workspace_vector; - std::vector te_A_wrappers, te_B_wrappers, wrappers; - std::vector D_vectors; - - auto none = py::none(); - - std::vector single_output_begins; - std::vector single_output_ends; if (single_output && D == std::nullopt) { NVTE_ERROR("not implemented, D should be allocated for single output case."); } @@ -397,6 +405,10 @@ std::optional> te_general_grouped_gemm( output_data_ptr = (*D)[0].data_ptr(); } + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers, + te_pre_gelu_out_wrappers; + std::vector D_vectors; for (size_t i = 0; i < A.size(); i++) { auto te_A = makeTransformerEngineTensor(A[i], none); auto te_B = makeTransformerEngineTensor(B[i], none); @@ -462,29 +474,72 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); - te_A_vector.emplace_back(te_A.data()); - te_B_vector.emplace_back(te_B.data()); - te_D_vector.emplace_back(te_D.data()); - te_bias_vector.emplace_back(te_bias.data()); - te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); - wrappers.emplace_back(std::move(te_D)); - wrappers.emplace_back(std::move(te_bias)); - wrappers.emplace_back(std::move(te_pre_gelu_out)); + te_D_wrappers.emplace_back(std::move(te_D)); + te_bias_wrappers.emplace_back(std::move(te_bias)); + te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; + // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. - auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa)); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb)); + + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (transformer_engine::cuda::sm_arch() >= 100) { + // Check if is using FP8 block scaling + bool exists_tensor_using_fp8_block_scaling = false; + bool exists_tensor_not_using_fp8_block_scaling = false; + for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) { + for (const TensorWrapper& tensor : *tensor_wrappers) { + const NVTEScalingMode scaling_mode = tensor.scaling_mode(); + if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) + exists_tensor_using_fp8_block_scaling = true; + else + exists_tensor_not_using_fp8_block_scaling = true; + } + } + if (exists_tensor_using_fp8_block_scaling) { + NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling, + "Either all tensors or no tensor must be FP8 block scaling tensors"); + // Convert tensors to mxfp8 and swizzle their scaling factors + for (TensorWrapper& A_tensor : te_A_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)); + } + for (TensorWrapper& B_tensor : te_B_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)); + } + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } + } + + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector; + for (size_t i = 0; i < te_A_wrappers.size(); i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + te_D_vector.emplace_back(te_D_wrappers[i].data()); + te_bias_vector.emplace_back(te_bias_wrappers[i].data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data()); + } + std::vector te_workspace_vector; + std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); - wrappers.emplace_back(std::move(wsp)); + te_workspace_wrappers.emplace_back(std::move(wsp)); } // For now, we only have multi-stream cublas backend. diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 3bb6be715..ffba5b276 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -7,6 +7,7 @@ #include "util.h" #include "common.h" +#include "common/common.h" std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) { @@ -177,3 +178,72 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } + +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, + bool rowwise) { + using namespace transformer_engine::pytorch; + using transformer_engine::DIVUP; + + // Check input tensor + const NVTEScalingMode scaling_mode = input.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + + // Get tensor data + NVTEBasicTensor data; + size_t data_flat_first_dim = 1; + size_t data_flat_last_dim = 1; + if (rowwise) { + data = input.get_rowwise_data(); + for (int i = 0; i < data.shape.ndim - 1; ++i) { + data_flat_first_dim *= data.shape.data[i]; + } + data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; + } else { + data = input.get_columnwise_data(); + data_flat_first_dim = data.shape.data[0]; + for (int i = 1; i < data.shape.ndim; ++i) { + data_flat_last_dim *= data.shape.data[i]; + } + } + NVTEShape data_shape{}; + data_shape.data[0] = data_flat_first_dim; + data_shape.data[1] = data_flat_last_dim; + data_shape.ndim = 2; + + // Recreate input tensor with rowwise usage + transformer_engine::TensorWrapper input_cu(scaling_mode); + input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + const NVTEBasicTensor scale_inv = + rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); + input_cu.set_rowwise_scale_inv( + scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); + + // Create output tensor + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + // Output swizzled mxfp8 scaling factor dimensions + const size_t swizzled_scale_inv_first_dim = DIVUP(data_flat_first_dim, 128) * 128; + const size_t swizzled_scale_inv_last_dim = DIVUP(data_flat_last_dim, 128) * 4; + // Allocate memory for swizzled mxfp8 scaling factors + const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + at::Tensor swizzled_scale_inv = at::empty( + std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options); + // Set rowwise scaling factors on output + void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + NVTEShape swizzled_scale_inv_shape{}; + swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; + swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; + swizzled_scale_inv_shape.ndim = 2; + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + swizzled_scale_inv_shape); + + // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor + // for it to be kept alive during the GEMM + input = std::move(output_cu); + return swizzled_scale_inv; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 4b2686096..57eee86d2 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -27,4 +27,16 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); +/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. + * + * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid + * transposing it in memory. Due to differences in how block scaling and mxfp8 store data, + * this requires the calling code to treat the output tensor as having been tranposed in this case. + * + * Returns the swizzled scaling factor of the converted mxfp8 tensor. + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. + */ +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, + bool rowwise); + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c001e8e79..51fbb50c4 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1015,12 +1015,8 @@ def _post_process_fp8_blockwise_gather( if out._is_gemm_ready_format(): return out - needs_columnwise_data_transpose = ( - quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() - ) - need_rowwise_scale_transpose = ( - quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() - ) + needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage + need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 # columnwise compact format means doing 128x1 quantization of it diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index a62e10bc5..bfe241f81 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -64,13 +64,12 @@ def check_nvfp4_support() -> Tuple[bool, str]: def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): + if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." + return ( + False, + "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", + ) def check_recipe_support(recipe: Recipe) -> None: From 5be81251d7e1c7b3d897dc2bd376901a83355f90 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:05:00 -0700 Subject: [PATCH 031/141] Fix bug where CUTLASS kernel was not being compiled for SM90a (#2235) Signed-off-by: Tim Moon --- transformer_engine/common/CMakeLists.txt | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 92b57897d..e6be47686 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -155,18 +155,12 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() - -# Disable debug build for cutlass due to hang. -set_source_files_properties("gemm/cutlass_grouped_gemm.cu" PROPERTIES COMPILE_FLAGS "-g0") +# CUTLASS kernels require SM90a and cause hang in debug build +set_property( + SOURCE gemm/cutlass_grouped_gemm.cu + APPEND + PROPERTY + COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") # Configure dependencies target_link_libraries(transformer_engine PUBLIC From 08779fd876562fb5d6d052e03d7fc0a5a91e1585 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 3 Oct 2025 18:14:59 -0700 Subject: [PATCH 032/141] Fix FP8 current scaling attention logic (#2234) * Fix in FP8 attention selection logic Signed-off-by: Kirthi Shankar Sivamani * Improve logic Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .../dot_product_attention/dot_product_attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index a19d08ae5..88e28e3d8 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -597,9 +597,10 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: ] fp8_recipe_dpa = fake_recipes[1] fp8_recipes = fake_recipes - elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe in ( - "", - "Float8CurrentScaling", + elif ( + fp8_recipe.float8_current_scaling() + and _dpa_fp8_recipe in ("", "Float8CurrentScaling") + and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) ): # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe From 7e45be73bb8d513abe8785ee078ac88719bcd9f1 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Sun, 5 Oct 2025 16:48:27 -0700 Subject: [PATCH 033/141] Added the NVFP4 section to the low precision training tutorial (#2237) * Added the NVFP4 part to the low precision tutorial Signed-off-by: Przemek Tredak * Added the runtime results Signed-off-by: Przemek Tredak * Update docs/examples/fp8_primer.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Signed-off-by: Kirthi Shankar Sivamani * Update docs/examples/fp8_primer.ipynb Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Przemek Tredak Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/examples/FP4_format.png | Bin 0 -> 50946 bytes docs/examples/FP4_linear.png | Bin 0 -> 54101 bytes docs/examples/fp8_primer.ipynb | 168 +++++++++++++++++++++++++-------- 3 files changed, 129 insertions(+), 39 deletions(-) create mode 100644 docs/examples/FP4_format.png create mode 100644 docs/examples/FP4_linear.png diff --git a/docs/examples/FP4_format.png b/docs/examples/FP4_format.png new file mode 100644 index 0000000000000000000000000000000000000000..8c54c33793db6d2bc0d00a75114ffff5996d5ef6 GIT binary patch literal 50946 zcmeFad03KNyZ_x>cavp5<7EM+LWjfiBV2NwVLy|BdA(OFO5Pb%srQ{XU#&MsD{wn^Wwa@%xjl zJLF=v1{}{h^=^}??#5WB#W(XlUESFrerkLGXX|3-bR&5|U(xZ)=V% zFfQ9LUE8OnQoC((HcJTCCn5}nG2$^!j}MCAIcqd;{-QO9|NKjF;dR1J^k$~C_S#=x zc)a6q)Zo#-yi@W!4aV1`(_3rZ85Da;j|IWW0lf0^w z?DOZ@zyG8tJKxk-%du|N@n2tfTzh5Jub;Whef#047SF<0PyRg5_n#8amHgWJdDndx z?L^DfhsH_4rQ2ffMiXySw(x7*EcP^pLytgqL4x?AL+=a93VMid|F=x^!t8!d)ry~@q_U@JEwr# zh#}2vGG~JM@gfE5hk~;MW*#OOXJNeB#&ApPL>sa?@=p7_Qm*q$)eH7#DOC&-AyA@- zgQ7f%l&@qESGl)2wc^g5uy=cRf3p^pR&{5%EMy#_H=Gf8s3d}@!txIefG=+fVsw!s z4QIbN^NJxB8+V`kR+%y&qrg1V8+k9=ulE%imUB;Eg?ovYath~&7gY~Hx<*CRv0$SB zJ?D2jhlg?7n~y3QUWf7IN^xoWLGQOu;5c*-b$iPe|8hf9{K`L&;$b%LXu_-?@^tRL z8`7#{$LwS8^=Sr;T2tU`0t-hJ77NFMi4=~A3LsYz{HdrzKRx$`{N8< z1jN1&8K=wFGp{tN*se!SY*Nk+FBN`XYWPwtVv9?JRp$zKTGdVaBjWl`Yg4UN$2jlI zl6My*B~j^KPD_s4+eWi?XS#-sMWYA{l9s}%Fp9iyY(};)L_scx4U2Yf2Xxf^#_|oRp9FGF> z-X^#8$mcSPSrod6TjG_|*#AjJ`^)}^@(x+NH%k?I*+%H)`Suvk&Mf;PjaZJ<%0w|? zWtZ_{%mXis1-?+sK8}8l>?#jYd=hf)iD=z(1oxD_YtSl^54J<*IELa(s~&7hA&yw(KON%g|LV(~9?lt}a?6;EP5OaKB>o`&zRH?UlZ+`^K=>laidVGY2HJI2 zVw%YSltFBALQ$@>90cS6R4T(Ip2$MwCmNrD+@@Y8tlPih<5MSL6LGeQk#$5(Jne5? z3DXGV#Q7fJr(V`Qx!TYAaFcA$*yfjgoAE@kfQZsJs)oOTZfvbXGor19@@183byZX( z%{#`^(5lVOxrUR*nyULS3d4n+$!kaCkp28%2r$Sa(uG;B`MlWv?Q0xgb9kG`dM`6A zC!dFZD4sSwG7&sXX&4o&2NL|3VtrU%3B)m_WH2-JbN5STrm@g0?H=!qsK#*irrgyJ zM})DK9`RaXu6{FLhtp|Oc_sYi+5CO>)l=p2FaeiUo+8rXk3Qw|viI{NXr6FOd18jE zf!Ua}0p1V;8fpw53aqa7sTJMJw8i18EXu=lf^)JrupL90=z8C}dmGCAcBA4vrOv;i zrSz~6_@!2jyasJP_49-vr)r0Do5k#gcb6BHuD@Ktjwgs583`u%VXtxGd%AcAU*w!T z%}&-zn5B*kPUzj5?JOzo&QU?j>%A02%CA=_v##D-eb6aMbR#Vx)m}lH+-$!)pY2Ia4lxAU*{(#s+ z3l>~7WzK`Wnq7n`EUA1!U#eCPw|5gR{g|l0-2cHVZpq~>?@w_MpGwb8kBLU`EK}cR zFG$i2`yjD?CU72!+aaQdr|1EQiG>FYoXq%8nA!W6ri|e3CC+ zy`W%OJfV|)3{Qu**#3bGZY&EcPkK1s$@ZO8w!K&W$?j;uj5UY8djP6j5+5FSGOGAI z8-}F2;CanLj;pv_%@JF~ExIuNv4%?2&3konRE~I~<3!&Y7A>;pt8tfMbXZ@Mvcrrh zqpu@98g75z zuy`sUoX4BrcV|ca!1xge;loTaz_KSf|SP9ViZw8ay~@FzuM;UDm` zBqn>+1WtUjXaFTZ+<(xy{0wAqn=yUTA)+nUEVTeh_Y{`0;dFuT)Yo=l^*(nCH8dx; za9W%~U+55V(PEt5n6g|OzNEa}4wIO5e1622+cGXZ0-jS}Oi$-a-R}{?A;IDj7rK6- zYfk4(S4x1=F$R%`nl1U#Csq@EUBN-Ukap~)QHk|0r%UJ)NvrDG3U{K5n;nEtoWyb; z@M8aIS1|)g7=u5R6^hRMoF?c+4YbP>C85Ve>MWc}i@{}Z?aUOk*a%yWitY%U<7-S4 z&jqA(6+3#7oC0H7vJHt^nAtx*_91cb5%0~b>y{po3;J{xZ>vt1adf>GzT!cM?)%jm zS#oOC>iI8KGFM>0Vz5ipry+XEr8jz=kwwH?2_w}v(9w?TNfR<7K_q9d?CDx%@5yK+ z9=V4!wll`RiR_gvtYoXmHDpy}prBD|L(bqKWm^@DpSBARb01$eIX&KglZ8xD!oF%3 zbD?5S8Q1UtmeNce@pnb=j6!v}+;cqixa{hPc=q}iyhclNnq1pp@Xe-8oY@JDc0Nv^l-@I#XUlw*_26_1z(Q(~ zzUO9|#jASX1IqWn;*%+@2=8}YIRKs2|GZK0L(GE2?!$P8B_s6)+`ahmcd~+F2XVv= zs-cbchZY9kV_AC%9b&!)tQ&Y2*N|1GC7hd8 zr7`TQvr*(o!J^72XTr*DJ)M-M_uIgVch%ws+O*7)8Cj&%tHsN%=FC&l`dpGs_X}>* z=P|=jUg+@gH|8?VS09|t)D0pDycg1n^dTxouf}zInoIxf)I7Uyp*-w~YU;{9Aury)V03I8O=d7@RzL^p{0><(X^Jl-uczDw3}M*_i*YHTx@H zT5pWXMrRizUaY{C1m>5q8`g3s@>((-AFqh3VmD~_G^M+YywzkhtHktr-SX>cIFR=> z7S6lC?a1%iWwj~#@c#LwcKOP1M010SP0Va(ZG%gT>cH0$4Ky4)ii<-oBTjTMNonDq z9Q_zM!R|=;Z5HPn;Q5dp*JNA_XIhqDSYo6szn|;^XB|*n({S_}vlAD^>}+XpG1QzO zQDc}vxaIl}rDYus-Q+B@L?kS;Y`HI87`h?nH%sdM&EsJ2Q>jlbOn*R#@Ju`xtX)c? z?wUGb>U(i~W!Lr-oYTLbIG|}-G;}+=50-q_P*$1Xv5VF)xdplg%yc0d<4?e!|!*&()lRC|Faf+SV zQHq`5c$$ooEe~SsG~cvn%z``W>5j_Ytkgw~k!JXB2>U}xt1M&*W9>8kje61zY2Z7wp2A5A|N?+yCi5-nW_-@CnJ9%zp;JV>*Y)OKH*F$Vhi zF9+}}BuK-vcEjV^QIR{!-Km&BH?*1$A zG^A$#{BA%aWI;dY277L>=Vtc*u-nLz^~h0lCyeqGMeaEu^~shs!Yc2lBlDd1yTz`O z-?#s>+9iUu>5wR{^syEj`d|G2f;QWYI% z_$^Vba!d=}_oTes=LHT|Peby@xnyKj!=06<`*oYPonRBi5%`HlRlmW;Bt16j<`VME za5i*uxD9uq=FUkTX;++*TJvYqMV7nKdax;Sv9^<&nSVv=j#Ij1Pglb8hk?bTitZ@& zli)s-YYBPpf)Iy$S|{b($kh*T-O}=;)>er;bdXna&X?3W<5c|0{E&a?DmjB*TRN`S zg<48Bj@#@NDASvC>cMHH6Hw8Z5cVa>6}5rXVe#lZbd%Pnu`2cv%5-BPY`Xg_jIK@0 z`2NPgy=DcqC9i}oc<-U|^1iTMm34mIvUhrfbVol&Euy#EYFR&$Mr!)82d4?Ow5uJb zwGjL%hXAm=dBGy^Xju36(v(*LIdTEDYq8#)yM}*v97Hb=-t$VOE&`)QM1BbsjBWs^ zNO7MMc19`PEOx!uVeTj~%<{c!sU2M}d$Y~pC7b$jok-8XooE|xEuF3vZH7KFu9qKj zIDWX$72+D&Uhh+W?&HtDWoVqjnb^|ZhEHc-7)H!a&i6kpP_o>KCWeZR>v{JmVf_R! zM^14eEB-7LMg3MSr^jTP4fXJB^*u3$0m&_*bPOf?%(b;Di!Paa7To&E57Ov1QFisE z;CSS>9+S+^!34Kq1O^(37=Nn6O;9j^Cxa)M3@`}Oy6DBu^VplzbH4$|H-)Nvls;_?e;)mDt9 z^N=R{+TP`zfp`fgam*g52K1bD)9Zs5#fD}>v!-WvRp{iFjh*{uR~t;9&-ZdGJY;*#mA$+Dm|~JN-jAZde2dx9^J;r+ zsYOeMkA?V3aZA1+I*MPdQ$@eK@4R;PETkqoczWo-p>U?OqX8w1ZDavnhf4`3BXEHXEKAFOG@o@l zMeD$9=t@ktl9o0GG=IjIvE9(YjU>*lfHk;Q6^ap)QWR!JzeZUK5{LfMX5{KsklZ~% z?!vRH@wS%raSC*?s4}Q9X0zryEr5i+4RdaSK7AElAcz)#k?6B6MviTUGv+TP`G3~P z3cDpu8Gy8?nSj|+kyf`G?GAlZ_I#H`p{}`dYnn9Wx6VqpWTl`cRaw?MJq7gI?cy@} zJ(j6|B-G`h5_Wfi29yRNR{2~fT*}tw;xk!Cx8Jd6N~fpWFIR$56wF44PCGJVJGFZn z5GGSZsAycNuEo^-jVy_kS#~=_1X>B6SJanYb#k!{yGnjMzwVA_KZ@33y?0ipj<58&Q5aQ%BQk8 zGdfCLCZ`|uPh^oM^f8IunJ?AZp=$J-;8!|!=SovX^-+TYSY>(ba$;l8#x`5Yix&M3 zkivtu8;0NQos}&7oOekn;`mj+mEy$JLwug;a;0QuC(>l6rjtS3Tcz|u-PkJ>=Xy(s zbMB$~a`+DH!QD#OfSv2JPxXFH?S||>l7_Q41#ro|&$z92OAm z0XFi5LN%{r%90cL?*gA3M90J0YC+e(pC^denijL$WmFH&^`(@0m=U+-XN7sQ92Yb~ zS6!Uir9$#G*GW-9FC^Y^kiX8W9erIZqD$(YDoNd**#^Dp7bRCqwK8_) ztMC7v$B_H`rIC0}cu(=i!g%*>X*2HF>}rm!*a0Db9G{8$em&qRjLAULqPVn!`JGL- zn=DMX>U(abuX7yk86A8pb!f&F_Ey=G^2N<$Tpyj924>NUN?Q#X(MS&QrA5TWyciwd zt9D2W-woOE-QXAg7G80xdDZ9^MSQW@P=8iWQ|Y#1G4fc*8E9N)#M?!(hc3pe ziQBhh$CK<>W&Q6cEYf}vt+DlanSuB6BWVftyOgkdRI_7PEEc-oJoVzuCA^+Ec4!0E zH==V#=KeP&l^0<5%|?TK@|+;ivW^CpdKW(^rh{kis?GaqlhqtO>)=6or&1wiDeBZ9 za}K-z(o|GkBl=j6E$xPZii=0N}n!_rOw8nKHc*f zM=fi|7+gVTUT#dyGGR5Bfyo%s71Um;(34UkHZOj)QT%S}lvX;Yy#PjA8p~>a50^iX zq}Wmiu6f5pjLKO$8@)IXTC#!EN57sU^->Bsx?6AN{<=e`yQBHJE|6$1XBl;;izzN) zm%e{item5O4uVRIrKSEGg6w7(=`}q`P;{gM^?XmbZQ1eRPhL$FEC&Lc4sa4 zoI5%cZ`*C~AQQ=G+?Fn+PaM>2pdB1jF9i3OMDEMIV2}Le=N|c9`=SE>clSdilPBvh zdqVQuzIP5XBrJm;46XT_7T^c5nF9`UbY~8w{f8*60YKC3CSz{t+R){!G)@>n9|@rd zVD5<3PoRf9(?TpP;y4`6r3CJ`_U*>^U!G2%cE(1!E+JaB9Wpz*_xC2aimqpi#T;V=e}^j7|rtG{6PdDoYN z1>*+Dm7V`^xD~AK^cYR;|Dc)K1(?3=-SiJMD$+19`Gvc$c@Dt4MHDkx-P!~VceY; zoWsc?ENt5Fl8$Hepg1t^=;-7mRNA1M!>z;uxDo#ZO5o4Dg znkcBZrT9fxd-{+Lozb(kfTHBYB2CuQ&4Iq(;-4%uh_G3DLgj>=L1kCns z{kCuuoPDnj#>!~xyvL~R-$D8CrN)ngXkvA$rTa}i3eXCE-ahc&b z1n3KI{l>j>!HCr=kS~itRe49o$$9V)lTF{>xL*eI$&?w*Ktk=pxrN+XBOf2$71JVa z)7Zaa%;gKGI2`h3uZ#;#DH^!HNdvJ<)oV(3R)fvIpI+AdOUs|xLXhIOliW^K0}SJc zyQ%LaV@JwXl~xc>W{Xt@HNbW_dtc5AU-JWiST%Nj*5TeB&MuGEIQX8=$5N_pi!ZP@ z9ms<{>pS~J$MunaTv2s>Yq~*At6+Lq-|jigVYw1lyN!6!qqx5tPU!d+c(P6 z6Yr#I1V;OKBH2Pm6Ll0J?1*Ynby@6hk8lHDQ-=jB?iGP{23iD$BD6cx*of_<@H=hwT| zgq-i$TAq)r=<;-#I%4XZP`E#T>&e{Op^F>wU_NEWMeV}qbUNkL!lre1{mRYndA#-y zpQ*cJ$o|gtx|Oit)U?Q~qNGnedH%2vj4zh@?zFIs@a<#X9mBW_;1h==r&lXU8jirw zRv-LYT;4R4A18vusiIeG-2GBHE9B`Q)$`Nyvsyk~MR(j9FT+Q2Q;GngAqu$+?GopB zEmbto4V_NgU%>6B(Od*TvU+iyviZQmOvvnw$#a@XK*u!2Y14Ou>VN*+V?WzhV6qc^ znyZQ1dh$u_koWgKc`j_B!HCI)%|jFC3I;)dhibFub#}@>fb_MLq~7`J>;w9+D5`{f zPYkJXoA4XE)48E;V#9ml>4$xuN+ia!r3Y1DQ7tVJKt$xg9YD~e&ASkPxKPUT1+wjQ zp|oR_25h{q4%VtOTvmlE=dZ8rGs^L=o*BZ{Z~9D49oew#nKguA5h+ax5BwW8S4liq z8azCFeA*L=ckQtGaij7J-mnq`4A_x=HjrpuU29dJtk#G|{>saB*Dd0l0sVIU_#KNA zJX*b2MG3L;(pEkrDFkvIrr+hcx9B z)M-<0Vdy_q>fxNYITNh6+W-qY!=l?bxY~)z4B~Ltnf;o zQwxw~@Y<{~IO^tGjYE&X^XA9b_ydT`eV%aX06_0M@ZMi#9(FAOFg_=sP2n#lMJcNJ zAO8uYZ468)qk8PCXFac40woON;h#zvGaWc3^Y~rg*wpXe&Eds{P|lSv6A98LahvOT zwO#hOu!?UBEkr4q-=fHpAZoaUZ=PgE$e88duqG;@vd{hrn7zUnhMx+^+iUBl1KadT z=6bCC=~{jc#M}4SPxNO75@ycl@6y`+dhbrO?iCOx>xBRSi`y#cq#)x87poV&t5jzY z+$^y81-qt0{~3dy{sn_4sQrS$L9aEEKiw<~-;bT1`dna@-xI&h$d_E|VOwSA#Y+)) z5##c!d%IF-eFuSd{m81(@p=W6P{9?Qk2#mGYa0kW%|AC)m!Ut>{rGl znlekFNO}ZkGT+DU?VVwlvPfjryVY9}*9gN%1zQ1BKoqLMM9OVNUromiPmLaNSVVn$ zdj$)G(26DexEt-}BZl+5Y)C*MX`+BA^Xq>3CCD|!Dm=5*IhbmrXM5Hu8hJcaqTKiE zDzY!Pa2jysF96RMwg#atOVru3r6-z^jX5BNDqFZZJYNnreX9ovZxSu-N$oI@lS=|} ztZR~W!4t{4fHu?o{P%RL77x>+whdqQ=7)4~ve*&z{-Sqw-gV{EcDKITD^zrwA5;=4 zG;s10P2)GVRjNC1jRYS#^a4q9AKj#RVhnUUTLrJt^7s!GD&K=QAZRIv&|^l+w4k$R zzdgv4see>sOyPj|b^Mc?dcpH`w(4-C8JCx@B;Q;F!W1wd_*;*?kQ|Q$W{M*s!+w>U_kK znZncugfzQ`c4t?J5u|kEUo_T|q&_b|m@_$pU;k8IfB$I%Ak3<@184r>a3*LctScpF z`!8DUW!5_nM-{J0F#1)(z3i$jkapX<#>W0t*u6~g3<%F-kDgk6``5!C>;W40Lg?uJ zLkn*MG`@VV&ia2i{6DGT*Fij?pC*(SkbClV3%fl`=)nWa;GrEu>UxE(S*_LE=@&Ji zVY!&-tL^P3W14TZ{06pIVXyY$r7+nB<+iDB&nuvI)2o%`K~gMNwfYUlYT&$Qzy?!!RK`M|I2ju z*UCQ1{!I%om+qWO5E|KS*6EGR8NaT zW!Rxl6p9<$P#7X$8oIa}-N_5c9F%^BKp?oe8i@BLm7}{r8vQ*r=gLov9{gW2@qq$c zYK!}DHnJMz@JVRmP|&Fzx7Ym}Bz;+Nwdw=H>J#`?+d%dW*zE+ALMTdaFuwQ7y^O z-t&gG$M-oklz#KODD4|Mhgl7c&;f6>j5pxDWwYE;oN@7JYcScY2^)A|(gOtJOptYtHImZ5I*cHq z+FB3xuXL;quW*KuJzz6~DRyzqYgN~sD5&i+yc^!y=*u9fZm9bWdW`K2h=-W|L-z!W|tquulJJ$Wv|IrZ}CLk3f6nEK+L>l(F{n_ zD`Km7;wawlOX;2Crde>#eXY`(i*#)!?e;I?&%;p`DkkJ6_v2> z7a!q{@)9#jwIwWZl$6?b|3&F{GR}9}@HPX8ybnl;y^5bzGvaFmE$Rzl{Y@HLvlTC` zJ&XVl1TVuC!b)u|>F}@FJU!gLJuV@atZK6DmTcR)^EM!9z4;`+m+erb87tZF+W_(` z#F2U7Ka-c=ZpwN11SsYE79oWUSBf5ernRS{#D&mPSex92r6G9K@`t+M#v$Sf4U-2~ z$!hJN4ORNE^_lc_fMEmnF|y(BXkoA4luz>Otm4i<(wTe64CKi0LxO}qlf=4{sX8Bi z=!6aIZtcn6UEt_KzNk%`xf`=-CST|Yz~@`bq*XD_kfvYpI5E&irYz~mvZba#;7A7n z9mc9&ZvDf`(2k9GH%^S!75Z8sI~^H*G63^X)k!{SCCfTWElUR>NeidF`yn_VQ&>IZ z2x&5j0WK3ts#!r)kf_b4^L}7=ob<)jmQja+I$CsVC%Off!Ym&6Xu+R9AQ4aECv_Re zfSCIR;wOuwjL#QUiG(kUO?}6!yHoALYjy|+6X%y4{5AsUxDbBDA*s6$fyOn@6lDf| zY+qbm9EGZzHuF_WMiUal5F0HWk)akWiLT6mnJB)K@4t(MQ1p}3cW9~&=inbv!A4P<77wFdzvVF{EP=^+LrLYG_7|KXvyK8Dri0Q zcRmjVp@7*1AaK+b?2^K(N@i0=x}^19yFyRr3l@1GqWm~OlTOe{>^5h)8~`kTG!{TO z+idM4Tr8cGh{3($pcRhofKyzJGPit5Xi?8TzE&0R^UO+1QAn=@px5j~3$>q6HPLk! zZBU=U z;ov*8JVDTlirRpBP^E0^)sN7rw9nbBvh)%V#`6ZP12b*xAyyIW)5B$zlMUcxu5Kyk z0AI%RmTF5x-?o97ord2W!z06lL{C?vVJnDLtB0v?5XrJ;d%0lh3rq=DqmBCVaC~sfm)y05lLBG412dwfuIba{mQ2A!2Q(HlIFE9q(Q9X;m?gt9bn_B#qR7@ zwmy_&Av|;sW><6A)c2kQ+vEB-%vSxPE%%}VsraDgeGnk6@!H)K2<3N|eQksoxr};2a1}uMYWlzM+XsFWdrY)HZd^5ODo&Od#;9Sq9q& znQHz%%cXu#Ak2j81_ZP`O9Sxwy13dKeyjBP>dm@EbP_Ys6_xN22*i~Wj>ac%+w)yg z-lFnTQf_nsX!y$r8;`C&p>x7&Svy3d@u{wRcYFjJyro_}38>EfeU!7JSf4oECAHYF zcjXy+;nUxMD7{$xhljzZ0xLZ`YJdv5QVArjj-M*Zwg(mxVC&!ZBJE3hO96<9U|*=_ zv;;K%hb(>hRkI!q(7Z0g-w!><)teh$xUoupe?zkP>)P*`Ts;B(EUr&45mzd-m-0Qf z>j=L^Zw#x9oH@mE&!T?Q!}H!K`!yedVsfMmpKz|mr>)8fR#MKv>yz6dy(Tei=lITQoCAr07P^lY)0s8DogW`xdwMXzkiP&5c zc|c5Cd6D0HWhlD*mMNQcB`Rp9n#@m`hjPTV#v1T1`f%mb`$AU8sF#^(5Xx*N;e~DL zB3{|04&BJ+2jDsWQ;`1tryzayFM@R4oW>tr8Qp+hdf!6nMBr2NZ6^XJO}8Bj+-Fu% zJ2cSXw`zO5Y1FDfY-iC??r5wepLzzQF?0Vp$1?e&Cb~y0PCHF8#;dKd5+qlETgNRqqfgzIzJ@qbQ|x$O8kD}p)Sho8O=b6g)p9_G0I9M_+-c#+ue%vrq5`RUA= z7yY+6Mg2Fh{C_Pe*Hghb-6%t$+N!5=HOMx0&K_2Ir_i}(DWOC%>yT0Cx87qf?F;M}PD@0GP{aNg8pkWDg7f zs_CZqH*&iu4tV(fR~osfoXCcM;c?MTB@~-k6*Ph87mOb^V+v@zYLc4fPl>$HUlDWTw||zAn_Xve@aY5bgaZA$r?`A) zpV1;9`bR{aRslQm(WK9NMgLMrp7^hXZo8fcxZ}+RWl<-Yd(#6qamqIVp$r#} zt*LMa94;Ajw(FCR;|*d`#6I!ec!gPn6eT+g zSc9vgy*HOGWr@AsACgT!8bz)2!^foaFX;{4HDb9|hXBzhQk*SOjwW_*EZ1A!&_!lrVq$){45 zkaQ^qz2b8)x(2jX^t+cTEzl+I{ zKPkDByE{ru&`4pP20CioO(us(OH8$K*Hp4yi@NN~R8%t?O}5kOi3v7H;TRbqH7083rF6Zi2c&z?PzO_;fsh#XBHDsfc?s_mDgz4sWe*5oF|^?1bgCDMt-wi!PHMQx0QIkT;| z?5ly4+*uu5)8VRzbL@^hAo#=6u5c|ORPM_N>qaR>nYQ@WEPV>rm;8}kYp&0+PCpbn z!7#_J`=Maxb9T3Y?^oK+Zx6evy=6-aV6-Y|zI^`HXZY;X${>!A1K6+z`eX3T%`DI# z)-4zDZIt(wt*iG9cFD8`bpw9X2#_CaQ$d$lF5E{Cv6 z-gcR=FT&MsZH>PYpSS609gP`vq$kR!l)4==s4VB-{qhMta{q=n^KAH71q@g=lL7hU zQ_Fxm6UE$M#@(R1V4ZB&bvH{$kTsB!_otpS)XV~AlZ;wT^jM*t-bdy1?;(eBl1WES zQrkQ36t@RtRXMUhzu|s}tA?Wd$-e7$Wcp%IF6+N0m%09@a`IP5sD}U}-+dQkV5JB5 z`dNr3QhVMX!hZ>dv2R-RNF>)L3K1t*M7de4Hf7lG3jeh^_68ZSUk|#dS~xmvL2C7` z3DoGRF?4ZT8f);HHc6B!6yiV-bfca)!OQk?*FwExKzS~{KL$k#992z+f$BUMLGlr9 zZv?h0_EjO{C8*{CWR%UgO6_-1J?n3xdd-(UI|rCRK`g}Heyk~1Cviq`Xh^Qy z<_tL=;D^;fK}??nwNhRixkH95qEr%zvjO<{xaJxg@PcWf-JU4Bz|q$m$Ry ztSy-6rj6`ZE$?%fRXsXY6zAltoFOfi8sJ-*DxHvcr>F^@AV&D??@lc{-G7mrt1P-C z#C&RL-Jg_lKHm}&JjHsc;j+IzVxq<*&1LIPzc(MFJFCZf7Y|$>NHi;I62M0d?_UCo z2iR1*31tj~zE;+eMtXs9F!u77w+;bnq78!_CZ-JwOy0@B;(zpUU3|C1hl_7_=z!+Q zLDnJgK{@0MNvw<=&hKbU-t=Q6{3cEWrVAI=`2$F4*dugpJfccmn zGOx23c*Z6FM;N>jVDREtJ%GU(ZKG)bgHP@<{)@}$cKE+G7X_Eq-2P^6fAinAznOEw zowFSd{a@h%gs-lR-C7anN2^jP^XX!6xU*4yY{~Nh#d{ z?1(s_FUu;^hFO}!*1K|xP<1IVCDX>s!i)X@T9%?s_((op`St5|-sTJGT32FJ9CgR2 zkz|W(hr&}Ab}wjO)Zmht#pLYW4!1clVYhBO`~WIG2TXvma*_+S?+s_$X^gAS#93&0 zaUAV|F!a`;>)Y6w_fqG%ga`+;Rc~!xnHloI9XOqnt{qlf0i4k3JAUR5Hn{#&GqZp7QE@EK$B|-@;#64seI2ck zh7Z?p&p4I^i!ae*St!m`Q+0FU#GPmBpQ$Q1`?yvFLLTEQA>^*NBJS<|DVX*qkzmk$~izl zfOR(e_|$P;iRm=zJV$#fPhPC2A4%G>Cx(I(q-S^~#f5GSScamTlRvM$c$< z6lR9SECJmnP8&e zD(hVJ5YcWE&4SBDKvo!pK2e*d%h&{dS!Qxt=~%EY^1bNFyCf_U8K>rmsWm1H-@-*` z!B=<7Xt=r?^51w0roA-KI@yF{2l3u_fE%v1fA2*cEv!(INdX8RuPcbu;U-C~D0aqo zNvb)5x`Y^D)n2M%zIW)8l|}j3n7`^<9Lmi(u;l;@tW|7pHlH<<2=28EqGqof5*wfYECs zFQrm3zNfQ~m&=#d%OkC&QSU&<1o+*I+$o}nLG(sFHV6JJBhc20z~LT{Ls5i!>R{OQ z{A^Ch0o#V`b4K7ps8k}lNj(G7ubkzXordeqf_oOW;Af2zyCc+BmfDG<-v2!wkff92 zTnc=ohiuyG(W5A=LB$B&TZyBuZ1MHL3;CCL7kEV*+ivz7|4{gH;2D0t&5wS`E!MD- zyGemM!3JHCF%oumEmfdcsur*ncyxG7iMVDU*DjbhBm6)DR_IufH)BY_;1X`O38FWH ztjg!L4M8heeHe8NDS}Yg#*WnT;Ym9p2VsCDAvV$w08`(V9`bv#2;Ao2 zSR%iZR16$N#RpWxgb&}P8Kx7334J0{|BILxrp%AIW~_qzEblPg9Fw>bVZlp22(ERE zB@)`WwlyW28}&PBv(2O}1h{;e*-M>pg0qb7Q1?q!VLN z9Ec$EB#7_C_O2^*_3BN@Nv~*rPNz7mKR4Ka`IU0laqaBAKroL2L|tO=@5luu7v8%F zb=gp3;d^>t;3rukzmW!O_B8bkC0S;hcyq>TicS6Nq)=@&>z-AH?K;v!$4lXsI|%PD z)!Kd=4q^N}&~yE_ub9AB%GhP^B~?}V97gT}$BSiQyf3y&b}}BuNs@>NxDQKeK4BDu zqXr1bN#HyE4anhrw5%Xz{3pJ!N1jM^CuDx@Ug!RqjP4zGtE4IXqki1sw?V|;+)O5Q z(5n(#L{8KBapJcn(+!6hx*n$-k>pOlou#wg`p7-sKePr1&Y7|8Wj2WMOD%6^c%gAuxQKG%4Q7{MlAl%1uKW#95g@$umD# z!KE7$1Um)bmhB{W>+-<*8)f60zZ}Vw<9@*?Q$8QL>C2MRrj(HAFQu|{-d6|Zi0rc$ z+?@QdKI^vpgnuy{pATVdmMkTk?seO@;s z9#;?tWUq9qk1YqaN=enE5QR1h{nD$5q9dRV3@H2H+YrAE>)EnJb!;Y8Cor$rg-~hg z8SDHIC{C=oj232YBjf!W*Lc=S7ARYDlF^$>oBp7#U#i2NdmNH4t zJ8ML(`;rV})Z|@v9dUS7d~R}~z!2B#tXfiMcXoDIJh&S_(AN5V=*IX#Z;*}>(iy}u zG(2x6Lq7}NL5gyV#fUitD_FggZCvma&Km%6KRv-a!-XY8qrTQmJ-?bet0XmMZ+6_X zb4~%~^meX@ARwrBJw6|3F27d&6Tj5D`EFPv0h+-1FrFJT7dsmg&KVY(t1}0 z)?z`$_YWoZC=z)bjwKEs?iyahf8&-H6@2yTIbhYUHL^xQA5y!FLP)23L{59ew~T-c zJ=+J{AB`;Fubi;c{K1+j)wOqO-PQiVO5GbQTs!~ScwHJCR{Ce-^|^Qdt1;;TVH_k| z!1Qv=F0nky(Wbn-Q1i@f@4;BG zzYqRP735@il;1FbP?fIr3*~2~EVO#`RYf$c!OTV0Il0VCSXZ{o`II(i_nV;P+PCp0 z-NW8TdICRRQ*#Q7Qxgs0SU0`QOQF<*#3Cy-)|~O4h$0;8LJCoxq___ZVOc%>-S>*%aD##_9+39RuB5c)aW6 zzOlzZfq#9WK8#eX-Sa{ZTKc}DwS2L5(cS1NY$wRL*OL&H#P7Aho-ppi+CYqfk+kP^ zfD{e)vAXgr3IrJdUGS9z8OEQKbKlYQweozb)Li(PKY|Qz%~0L-sEU#E+av`U*RZF@ zYGC9bNnur~Y<+SBvippHYh;LZMiDQpdl!q1^cPr4AJ+oWm5Eq8-!H?}5DR^5`I1vn zoe^Io#`OLKj%~UjvN}dqd+hT-*tQ_GYze6bPR*T_czp3PIv*Vt7RH>NY{QlNkkTyg zmSfIeO=4XvAkL>x?icC;8KGUJQ`!Ff_(47(zPgz)l)S|>NB;&`mrveL;%y=b~a}P zB&)d!6;Z%*X3Ui^U7D#S*u|g`>005FOY}mg#>?C5RLL^V<=xe|FUr`@2{Ju5#E7#q zGBF{wiASfZq52m|ZGsm;Fj#jsVqhUR6Hl~=+Q7pQzr^I0O^f3WO*G`0Wa$N$sM(i3 zYx_IM#y%4kOzZ(d%y-HU48IM_ka?uuI#u*(El^}TV&9cU+V`){*~=r=UqVomLv5x4 zi#)WQeSH3(_Rch_sY6}EwY0_73W_47Oj@f{5YQqZB1s)83KaxIP$oqsDwBeM5<_aO zQba%tf=s~yLh?gfJvR1O$O3M1}+;B)Q)P+Il*iwa)o**E#pnA6=^y zA@1z#WIx~Yz3)ERVL;~|Cu28>e>~z5z55n(h}Rj=gTb+F(>K=B#T7NaHAeC$^|Dv* z!%%`3EJL4bqEFiSNY>_vu38SX$rG4FGB}Du+|!&2jJx-e)v;9KGTbwQv^^-K!m7_6 zs2lCx(uPT59jc)b~{aZ%V(iC93CCXE2^kW7W|Qcijj|0Pgs!H_K4J)~=?26}jd zq`AC8{KDAPud`R@Hd9>K02EZ=*MWj`lYY>SGLYOUwZz-8hgqI}xWFa`3mQ^w3BUO4 zwLBO~69=YqNfTY7vlB}pXyqY+LvO0o#Ge+dN0s(+*N}G`UV?|DjG|;h2o8FPF>d%C z1LyGaHxLi*_C}#vc=3}Mi-8-TKeE%2SFCXMN;GsoVFxC4lg+`g{>Eao>OOiidR~5F zGe8>FWEhL$^wddP<1kSTq6?N`N_#H)h18z9b3#`dm1PE5P`>DPo^13Q08u1{s4%C1 z%3^VBVJ_lV%u$#Mr)3P`aES(TvA4$h5#=x~veVOovl+ zUq&zS7qJE#Qa;z=L3f3qJIcsEcqd)Q9s%#wu*7ZnNdc8sV2145&_yU^h^OC&ZWN@} zL99y|`k~OQQ<*c3(Cb|%7N&@G*;e+SzUdN~rvX&143A(|Vgr*Q4T9cgp@yqB+>FTy z|9MqC_$fEmFf2};G8L&Dke^ALR0lO(659DDn^YM3cJ%Hfw3@m|CqK5lphAqARZbJ0 zmn~7loV;gwK3AA5zY>BBHs!%x8=86vt;qMc$#$JJeekxoFR;N!O4qXXl_z|V-9$3U z+QOLXH4?yp^&YOa{!P~t$<&grN`>OERg}+cB7<++c`bs~hL7MgF=RR;e0|N)fEB6T zan+#L3e=20_dM#i`SID;X&7Vg+Wyg-f_NC{F?f|{_Ys}Fq7)zsDF_^UIH&5FH~{gr zwrh4eeriux=r0Yr7qe%G)Ji~v;;u3T7P?z(O5fag} z;rXXSk=GwgYQYIGRBcJ_b;Crk3m+ZrT*BchHKX-I`f8IXVWPosvm6%iaJ`aVw|SbG zYvI$p;jTzl9YnRfC`g2#bJ!`qZBaXC@H@_m=?TQy@3{ffNPXzN#PMX?tAoA}V~rQ0 zUg#YY8EdzF>JiD3=FETIZNFEl6e&W_^M^>Gs^|QMi&4scXz5r}UGMI@r`^5d4d7(> zb~ZJ$hu{)Qbt|N@)G6EG`sgYM{Foa6)Lay5t=VKkXo&Rls?#Uu88H#OKAy5JWJCHA z&W|H|u&XE{Csx?@;VUCn=t=wOhJ=cgjV9qHvTVnml=Y(V93R>}x>Wl>XVL=^D#P># z6{2uH1aLrCa+Fo7qj*`YS%vO37$2|$OnUZI!_0t4y!oTBHEQtcttVl)T3cyPutON8 zyZiLV4!9=A`Op7^(mw~3{>4s_595YCeg}0z9<>!a(G|ggO*B`mV^%k>M@Rm^z>_}p z@{vTJ8eOG_B48k#%k)U~BG*FiXjNetf8<0ySA8AP7di6QGb<$OB-lxC5?SX;R=&|` zpC;`L(~P8lPTF@?5d%-YL)x?c8EId2ag+>yo$g01>5H|S+YjlWumwnNO^fVgp&(4U zQ1UV!jRpSkE_2S_oU=FQ?9DlQ|8JeW z0)Y*``kGZ6!Ln>ynK#V?QzdMd5ft{PZ|hf#OV#g2A+E8K*OxlO1#PXnW9;E`VU2Dd5=raSnX`$w90ozpG;S9Obzmf|OMH3S4ze1L2ExgyF5D@<0JC`+xa zH%P7h<%(?zsGs|6GM8moDk`--&Kz&+4hZL<&h&jLX^9?pf4o=ZbC1a>qML%{=i`^Sy5@M&7InLj5t*ObhVvD z>#`U3ztG;JqYpFjkgGK3{n{nJx&Yzc*$^G=%w?RfJw^rYsr2f^qCV6auC-+FHX}#1 zbq%Ke{;ekwoXUt6HX_8`$RDPh83_G(O#zVk1x7S49Y)`3?5oK1)*>)FjY>6Z9rVcK zINp}B@J?s&-fIUOecOb}#Ssi@qj>1-lTust!_?dZsoSsBWi$POQ`-Vgxu8aL@PfGL zkQ8H#ZQg_WWRn^t^wdeD4OYM|u32rHdJIJWo7EJV$s*i#2+6;m=ut#kJ5~Eg7m0H> zx~3VJvGGDY++cbmLD*(fq=Aqcz#!lI^k=H2Q}yNDn!JJb0Gk9UF>PHbk#geChWXp~ zV(*Gn;EvdWF1Zh?f2-dfOhnx*s$NB@5ceYWW*ubdDoKsHGg~ILr1;jUXov_)Iw#xt z<_A5?&s~%DE2ao~(rL`TbK;*{&>K}_tC5tD0RxZS(qAG_M?$$c>Zm^X9SrY>OeWwc zB%kOhtU?wb++n*-)hB4;JQ8E8k-lB_Fx~SA*E2{Bq2({#0>8{uNlBT`s#O{@%SAD? z3K2@CICACRAk{=8UC>+kb)E;j$Je?xVeIoF_~a^`>p-hZV543?yMjoHMav!lM`xBW zfkRVG=^%z1J-D+Yx|`)W-HSX4oC~%rudgl1PCOw^^{7}~AKJHW>YXn#F!Q_*)& zVHv`sbqjvw(XYp5cr@SNaJmY1n5(}YA_Xj7dXh%wu}uIXj%424eeVxKxI&R$3HwCv z6g)N4Y`MhRw=ig$S7`IHrU7%9;Vd00zMoxR(o_@fewPzEFhxgP zZzM=vdN|9i`i!ZIe8n!$HO|CQ*BE5hou2y%t4r9zjpwQ5?i7rI^f1&-Ha$ zD4`CDgxQwx{3|dnC2OC@dezLcExD?S*?IMZBz0?KFJ05GVN#=m@3AkF8EEHgF1hy1 z9(pou0c9aUEo;4})nnf8p7&+lc*dKHM$Mha$fWV4W7ET@Py^A~jD3Kkb6)L379}=s z$ssPuNZt>!EF)CMU$ux0R_f^-C<{#=&_aior$IIf&2}tOmvFyZepzxPYF(N!re0Xi zJBiy$gg@XEsAyQb_!-g?n>@P%iyXJk(Y!OU5#F8(CAwmS8|TVKjT(;q1G4_2MxQ)# zG;yp-1^&Q-MnKn7O;z|q&gYI0g44dfn}^m7xhcHrBIlbY>KBn3@P4VS{PwP}<;q_+ zyKApJ{A@A&${6?vctwSWSyl+WmH~qHZ-9~4>wJQo1;MKEZ=`^uAuEOy+n$Xtx3w02 z=ay=?B`09hmYk$a2QS{Y!PWOy(1xP&@Gu)ZTjL4J=G>c4OKb4yKzN(Qx~te&#DFi-q_I`^#)_Dz**E-xdL)BLe+tBi9m zROe{Aye+~teb4ucf9=(*#Qcc7mm_a^zr3`?zB*s!D&QKK#F zil7M8#OGlTkQ{SlmkP--H}Pq4In~fmn`EpUe$LKk{SM|**cj?$Oxj|H+Vf0e5<|-m z+W8*{s7jlN2zmD7Q=rvjWL@@$g=nofBE@tT7f+WZPtsulhXFeUFoUQm*GtwUo?(g3 z3p@Qn&-K(yW8f8GT~+x-g;}19`t2VWKX- z>-tH|z62Pm73BY9y2%zQzBMp2J>Ljj*0k|8q-?Gq`4Cogdeip@TYZTQc12!c7gdPp zk&Up_uyM!`K3BNME0Na&TJD-f%X69JbGPwxxAAkg@!Gp07a9;Ag~6p^*WA>J{4cjX zR#S|~E3LE-Va0qUL9!?;O^bn%V3m&Gly2#4_D()yWfdVB8#@-KG~Bpy$CdjXWpCX> ze)9aNJrxD-R;uPdBpKdpSdKRVygB-bd8xq0cjtF4Ti=`auA6-i%0pfAjI6Bw@ylP? z!GN|r8s#H#jiik@OIXR>zS>5tF~ zxzRD+qcxq|QhB_pjb0_JlWpEgczAWZ;qiCx8I^~R5hX}&OW^ZQv7ii|pUZ98d_gqd z0T(ylq3((b@xLRYY0f8~^U32~>R=^$&L{s5@yYj04W#=PVfRhZ{BR?U?3_P%)e9$f zCL3LcqEWc9Taeg!IEwMs?^hYht};jVjl+WH&2TG1cxvpEha2BtlhCH1 z7{3ZVVv(0Dv?)lqYhIl^r48H{ghQS(&t26Y*3 zlf4$8$5pX|c&P>-?neA1^;}->?@G1pV=6(otK68y^O1d_5Ri! zO6Cr!TdY@^A}DojyGCE=Y4!G0FKEqEa6=o>Cd%F4s=aoxYY2BO;Cl*wY;-Snio#q} z!ySt8>zw%B{cB9V`$%Mi;+YM3iNGH_hdUw!1w6X{g-sT|;wy!|ZwD3EkPRqCY9N;L zaon(9^3%ticzphSw4_>*ZxM=5ht>R6rtr!=MK}Aeh2^8R15T}#7yLMsIwZ}QH#9`q zonorB*%nshkjNZLh~Pl>B5dj>j@FC*B9Wx3SmFW1LL2O_q`1E+@)xt6Z? zKkL8xZ$Rd#^~jO&9iPqT_&BbL;}nHD+92bbmvyfk{fgXeM1y<|mPc?$Nf1Wh^c>{p zNUI6?6Y%PMY!cQE6y9sOaJu=Pefj4+15Ix)V0vW|vyk?Sj=odOCcfk7GEPu~Z}Hhg z{7)*JhND*IiWs%f_3;PM2yvIU-c7-i|?vJBPmyEx%tE zWA49KK$dq80MJ-gcEe`+fliHEBgJ!@l^mhkuuI<%{63pOjhbj@Zw*{Xg?!rpCW;l7oJ>G-4_oB zjx+elXD`P!gb{v_q9O5mZ7_9dw|WF25=j~`nj?*S-VrG>z5|A8W-2YO{`Qrb+SI)K z?*Wq|IUrpy&Vu1mUUVnwi=$A_LQjX@%R)2UJ#L&a6fmOX+aK89W*O5fT$1vAVM zj)?lImIh#*LkWT6W%$9egOA?ll!XS6mTZC-r!SP-DvlPAbg#J=PCwJIA>*;KM7JeM z5V|_kt@^T}4M1er;DFvdmMQ_+Fb-@}JYt*gsfaJD2nJFTMdv0joeN^2HVJZ^89_js@K z(H%}m8c4#d76K!ASBS08<@*IMtCLwjGf8p9>3vfRXC$Tced~#EB-VrB*bNiFbGE)q zd_a*lV#e3O;|kxmIHTguS3Mn(DYR(D=O^n3=J}+i!RX%&>3=*hluc0uGex{JKlsnl zABoFZW@*1ZauRbe_O?*Hj*y+lsKEDNs!cVxT->JU2Dj0>wnqhS!Bw}IgOqoo9bup{ z6#Ec=WdBTRy_B@@xVMzm7E^DV64>`CPjhpdwxmlsy@7L>}(Wcv{n2vmg>^FOGwFHVw*3e+6z(7P;#=e^RBLQ{JlSHn7$_eCKS^;|6VA5 zdCp+3$vV;UyhK*_6i;Y7#av?NGK}x3ils(qhxNaFO*e&ZOXn4NXr2X3bW`9xXrcPo zOfak+C2fStuidVjRB1isbh3B( z;N-(lA;p3|+zXl7^vjEAZw@=N+naji3MIqw4)dQA6cqWliOV?Src)W7f_Hf^d1bHd zu4x{e$<*%ueeNltEuS6j`fj^aNvn4-B7gX@G?>RH$tOwEqOx`WH)a6S@$oRLd{MZAhm$xwP8z6%KDWA7z9z@*=PH!I!vV8^8bXRy!CwblK0;G z==84bTlNAxM=o7E=6V|E)TTKmI@{Ce;I%CZb6*=S6Dlx}oC`Pk7p1*9UwPgKbLRA4 d+2j*=%n|eMq3uoc-h+R3eQUctf9t^?{s;c{mD~UT literal 0 HcmV?d00001 diff --git a/docs/examples/FP4_linear.png b/docs/examples/FP4_linear.png new file mode 100644 index 0000000000000000000000000000000000000000..2cd4511adc96f04320bf9bddd2d67376877837ec GIT binary patch literal 54101 zcmeFaXIN8P7cQ!Z(t8s{z={wNK>_JtLlIDEks1&I>7jQJX$mR|Mi6O=6e$6uB$N=O zh|;BZ3`p-Kv=BIR!FKEZ?sx9_b?o`Qvu6*( zjq6u5_UxgA@7c4DhWY^bfTw%vK)pg5z%PNbksdgOkUof5-sbDO2~EWf7-vx#EW zelB*AH2T?qbD?x!kG(x}_RObsxjU+_J<{{iKc#;XdY?Y9B)df4FCaFQVE^;R)TXr9 zqOTu$3bB32n}uKT;hwz|l+?_>{y~2dsX(Z`ZO3@*KflSR41}Gf{PVZaL*>*I6)2W@ z9+m$XhnX3TJ@9+jzs5gKO%1n!(>wdCJdkz`eKW-5;~t{f-c7 z{hx2+KlkV%(u`6zg!Mm%3522l^KJi5`ai_?|1(Le9Ox`}a`Vt!>lX2St8`N3A1dK* z$>Cmc%8#qeq32Gk-j-IUu)#OujgDTM8`#N=xo?TAC;wJJ=pW_zv^F{9-pT^5IQguR zTQ2<-DEgFJ*oe<3PW%69Ihtqr4Q@#KSD>6%W?lxtxK$sq{;B=vMu3?(vz<1$VIgAh zt=x#_Po<9L*h^3;e#4IbkjI6-Z=C%PBdZU6h&lF^_}m&MPriV*$^2#_HQcUrofD?Dp_+v>D@6NPT!yTG2vCg}`{la|R*bQIFNqyZ zPn{wK9vIb9#A^5r1;L0}`6qzn{JVDqb`uDUwDJ|m^OO>_htU-NCKV`qpD~3>s^@ZN zy=VHCdM(zr<(4c7 z-m_%lh{;r3R4~_Ej_t)6Y!cUlxcTnZO6yex?AYqS!f0oR+wAs6heM}<-{p&=YmpT- z3QOI{CBCh$jgq^qBqn)&*_m%|+tGrv!cMFGtAWF}N3PZx9(8F-x_GX!+UxnogSvuo z(ul}<&HOtzzP|ssJy^Ux%6nANtUQWLrp&Aa{6VCEi;Sush_tkTICgtK?aJC<7zWn&o+7 zLU?WiGygd}p_%rc^k6+3-fcQnop#K3t4asyNC@ACmCkcyTNC;Ap=xBUhLP3UcYR$y87Mvi~8ZB+1lbRjUHn~gPy%JCg14dxpWL&Q1cU-jSs{{LM|^f z1;i@bY{$|0wb7>Di_4w6$hA;vk6lP9)Qyv9b;hNmAf-7a!f)`V)I;?BpN3!@Lm7%3 z>6!8y+}e*b*v?BR^zaJ0x)A2+9nkECcdvb4+)*{r*@*XW?75Xb44Ub~db@0m>AHAQCU``mWmsnSJ$l%@+b&{1T zQ0}u1W{no|j#s=+-oS_s4R*?fV~a}qs%;N&Uw+pjCA!g~W6^^N#VtYpqsKWiTw^wr$rSWN8N;gpV@xDa)&!+Q4p* zuCfCZSAZn0%=6NzP$77Au@N=0X^((1716cPsD&anNiz$g^|h!jw{`5A&bw-vY1W~` z!9#B!hR#*jSj2C=5-2R|EIRa|bbW?~&(lxl?l|%8?e@@!HW>VXRCwc5x_ezSo|n8Z zBC#Oc(_tA&BBskjlq|Xw(4WQ8>c=|Z5U~$G1I{$oSJ1u)GN&CzS7}ZVudBDqA9K+` zSNYb3_m-0U@VNz|cDY&f??#w+>bSn(lZaph1*5xNIexA z=xOUNgv{T0*ZkR7nbE4$#rF>Wd3uyzprL0)w|VH_URRTLsYD&S`#79b`ha9x$yMg5 z?zOLPaW#J9Lqi3Vujlti4GtwitP694EAeJool>1XXHGZjJ3NoGr<`1U{jO0;1T*KY ze|y0^dGTOCymBn2-)J>uejs3=^ao3#;~vhMna&jbJwWescJ$u36i9qwV`b9tKNLS8 z5VlYniWnI7n=5Rkt2ktx$?Lm5Rac3@yAd42lYlPVLsVOZt_{Z7QwW9Y5??0{b*=S` zo)mF6y6e)eUBT5TFHqIV@3zvJowAr7(KR?*GRuR;PP%2md27h3yrnDqD*3iI$5uv5 z`bUcojvbvsv!U^XRYFIv9ika7i(XPw0ywiSd~1Ov6~e0(S@w4gw6H!A=seWV4xjBX zhfdIxI_<2T;}7uQh&^P#;D0ay)yVQhf{HBW;Dyp*gaW(oGJ*@EE*{Lhy*zZJatA-O z*~yQde$wi#ZeVXls_}MJeCn%f~yu- z;oA55vI5fiF)50DY;_ic6P+#@o1JNJZITIh=8S3Oc4WLt~$K?y=fX{e86~;QF)jS3bFMzmr`Op1ug8vuk(9Tl8l#F%2)F;d}{#3oc ziPkw7#QM$){_o5B{$1v8=i%So1n12E4z9|-O8c+ULg(o(9R62nDXIUz5=Qtx-Ztwa z%O$p_zRuUx*~g}8x}t%ie!ceN`(e3@-czRbjVZiAfGoX*c*Bwq;0+0e*iY@&;f(mQ zp&t4xug*QA2=77S=4WnA%REbNQfpZ7vq7w&dq+pTpJfT_1;5)us63}ECOt|+)Jocy_Fq# z`~~Dm_G|qe$+_eN<}-SxracmAX7FhPIDCqJAY|otT$q~qprImGteB7^M-2}~SZKx2`oH~+2kMI0me>)jCoB5tO&vUt z;W+TsC+t7$-%F812V%^UJcU$*+8jp)UW*ZxfaqeCQ-E|{LwFezRBVM9fVbR?eTYnU z&^W@%fBPsDw$g|Z) z1q84`_j!f@p%#0ecoWQf3rK70njZLdG4RTqD^0wZ(e$v+9}YU`3y(i$+y`IZUoN2H zM7O21K+&mZV8(|8+K=v~;DGi$D(GLa{417!wdH>wBRV%nsO@(fWHZ``Ay>3_8gVrO zFFI#@E92zbFuQ$p@K0}9_p|VyQwG*86=RdC!v7B7>k;3Mou`hxKBLQuF58$y4tR_O zT@B-GkC5J48?N*v&2XJ9nf{bBTRhbveah@OK}E&KRPrRh0So$qi||AUK#siAAVOKK z>!MD%$Oc8cBA82EuL!b&U-OwKNHGa-|DL=4cS4zSIgj#6mF?zCQSOY#X`G z0S^Mfu$GOIp{XiU1!gE-t?&1=rfka(4~r$OvsJdnm+DT<&92oUDpK_26=cowYC>9{ z%M|+@fJ*qMs2tBd_&?VQB38_QJKTvP9~GS14@ zlQ#&a7B(3U^nnP@(4B;Zu{+qqdDkRfDHnnYy^j1!Zg_!Y#xteb0k1inF;mo~pWEU$0nCciw~3qBh*q04v%CHdA+QnVXo;$3f?E*ZCI=|j0z0cGRxMvF7{Pa6g%Oq-wZ-r*#(zQt>{w9Tz%F* zDWcQ#3XY+d;nS|6fycD9aRH)o0O%|IXtMZTib_BQf}N~~(oPUm4AfetGN=`$;pohq z3R;s$H3^z*&-M&GnAT9|5x2hFzCrE%XSHWy_Sn2r;{Dy3l*qqQZBEr^IruaVs)+k?4n=07OO?<uOsf)VGoM7f(J2MB&US6 z+wLP}_O`Z2ec7_}^`aox`5t|*pf=8}zkMJUBgAuPO#Gw)cPhFJB2fAuX-|7p(U=vW z+Ktoja8YeStx!fLD+Fha(o>l$-Boy*xzeQyfI_`~aH5bo48DQ5zk}+|T+#Qj8g5Of z#mSdg-5J=nAcL1j`*9%1y{q z@G|H!#bb14Ys4C@Rk32~hCXb+N(lII5MuGDYGWPyT{?p-q$cJ4Zv?{}O0kztmR(I-YOjm)g9jcGaj&+n3@aRzC^at0(kI2Tn_Q>Y-sip?!KBP#mf}Ld-!Qp1*8!jK&SW;c<;aQ)Rk~~MVm9Z{v)~`xX zGM5ud%Rb4WuB~g3qghDp$%@S zrBE3|Rtc@~M&q_t>5RqY*P=BJ@_!T^4Nt<3S8MTMlVgG!uTh@EEaT*55YL#azMP_c7+J>7rIWc;4_xFVsJ zUM)rQ=!;xvUw(pp(I|Ll`(esk2IwZ1PRIrOX`zxc2tJ!*;y!8G!I**mbZP3uY|&~q z!3~c;&O=%s4CLD!dQDJ2la~}v4x|dXj39O)$NE=;;ka#O!p7U}2;HDc)Oj=*&WcAH zoQ@ye466|Q2>b{>n4xpv!-_~A44tFlUadLE8vp6DPL|Tv(qeXS)7E&J=RS!(tGo~e zSm_4$9KGA1lZE?SCxXqcY-S@^IojK;GWf8XF1?KJk&o^()0p;8^Yh}8Ms{VP4sbn_ zWMucVsKAmr9IVf=m%$WYSR4df@ODLP7S4YlCa^un|CmQEh;y|`UA5g(29}}$LdAbu z$^z2tm1Y*L+C9H0LswmJI15+g*6JPHrY5f|T|Oj+-YwDFPY_Iq zu`)??RrW!2>C7;fz?&qc1BRB>q@bjM`KtMG1|6o5`1vK|=W0pei>1_XG;ZQ~&gx5r zxHs*v3X~E$;({$mWcP*n#UV`F3-qYkXaqwMLt!k}XyBoU!eJ3?Ei9HxhR65}RCV~q$ZCepRy@?ZO!H1r9zXBtwRj=so-Hnk5jwypKC z5MfN8*ekxhLK5lZQAvC^t!{^*TfcM6<>#_9K9hyZm4{t@ z?^}!4b#obUy^dSVIGQZrbQg~_6Vq1pWhd1BY@j-F2di-_lNzq0Ke0tveCx}*<=Zc? z(r@~Ht#cDGFyET!C4K*JY{$@2`XtO=hz$+Hx*n9Jh%~FG;Jx=uW>^HGzh@@cHnBme)&PqBH z?Jf6a^W$;8p*4?m^v$Y~L%kSk>#SVIwJxr+bg4chH`LfUweS|;ZWhxY=~Fg-mEcBD z!PRB(pkrMY7EHzjW;uMFhN}%%qfc)i^ZY(n##ZRNd=s3|Y1WvSc4RW*#cPmJ%pUED zln!ACf;nBBP=BLH;qSo+SKlIhV&48fq!h**qbMKvkq#bl`k6=bol2CZuaD!}X9lj| zknA$wE<5r}2%|0l_Ud$$Z}?Ys|7uf715+hx&A_t?CDY*qHNK(;pH<~N1H>!tWDlO{ z+Z>aL^4dIQcbc>~X2)FEXA`%wbrx>q=e4%RC4!hgcFocHZV z0>Tw?YrEmm3YBpG37e%?E<)MS$7F>zHb>h1?Ai_FF^&6%t>PP%!{SS-FIS+qb9}oV z;7Ao!U868A1Fq+!Mf*tkd-q3bqwU$ZTgj&t<|1(3kXH+fVxLLmN$&w(5uQ=ERJBrw`S;EY>PkOX;U*%**W>Md0cC{!hkR*)BR$ho!j@u*ne0>HH$KvX} zIeV0s3vX9M4o0>xJ*UPO6Yv$49ers90?L^JzU0N=gg~~l?yHW6r&bQ-S+!r$#)>`G z_Bh`6Y4D?hlys@%g3^rBM4^ez*jjb0o7aVu!PT&JVTAPv*&%dTCOcs5v0iq0_C>VS z*$k^tr5mwz!hFP`mEaEssGOUbc#e4MKJEzON@8h-J-NGgO@Ybh+FUB24u|%JQ zwtm2QgE8x)!KPd9NgCZi829E0aSk+ls+i5!h+_tV6_2IEN$g6hQrEdp8F+jd=3{5v z8dKAIr$VoVE&@j17?TmGEbmR!xqXN6d?^3VmW(YP-n`F;Op;NvT)yrT!iK#yKTO{? zbkf}N{Hd3c39i4O&0wS<*cPDonFND)_dfb zoLeB^r~DHJ>k-&E2Jr?L3*16{F@O^G7 zF3-6o?~<#D3bfXP+efqbPj0pNeYVFRieLKbNAFiUoPs3s`5xijc*vl??Wk3-dXtS% z%dewNP$9qv&Pd3&aNwmcM*=F@X?i{&b9%hvR;f`uIgP83Auu{`Xlu=Tk@`&<3usD6^Ddqu6{Ix!@^$M1?Cc|2r z1APU*SX8xDj5ADY&G#ONBfV?WrO~SQ!x#3CrcTb)E$K5(58c%!eQ=iP#BMxD5Z)HA zVu&jVf{`cDtC13ZS)M_#EjPz4G?j%aNEuU};`jvIQpr>{faV-Jj7{$2c@+{&y%Ms< zD5~Wn3aiLuD^%?{7`#|XD~K1zD?GV^16cTekd(#DS(;vScSMsTv5m%N!)VWO2@tcR zzBoj1@f`qX-ir(!CoQs=J~51(xcwf%5!l8+pI#B~9*mPf2;=3(w#geyKC6*}Ri0UG zX-lVT`S)JwuX%Kh0iMq5F~J$5^~P0ZuGn8&s(xbSx(=o+{ma?W6?`_y*S+s@LZLBH zP5Yq6vLW$-X7)N-(+!u9zixff;0<2N?F-0rA7zHOeA%9PdbI8${uC$r>C%Z5+*}GD zUHZ~Z9hB}YK}8b=4|NJBp=C$DGtuAj+w950hlUT?EJRZk>ZGxgFw3;Up$+cCNm3(l zp>W`J=r?F?k%sOhE%I;{q`0noxMCmq;tdIE)JE! z!gW`_5f+&A`-IGcNLgK;CAZh3riuwF9^%*5>o;)SDW&imahYMl;~&IC`Dw;OS5AA)36V>-0#2)0zF)%`;dpKc@*sA#J-mzfX4mIxV=#%t#&9&5G=qj!_=~03oG_jL{};h zmK&_JV(`&Q1{A{_)Z!^6otk-ODI=^A4C`-9?M%Up0J?9w&eD+H`Zu-+W%w$gH_Y-U zxd!b_F3FaaI#^kpjHHJTiF~=4u+q_CB#e{D%{tu0j(^akg4GZY>8HqOv^H*Teup60 z5md&AYfDqU%qzWCx_5n01r;G>Z+a^C(o3^lHk=b42C)AQ`Uj@OfC#q24#cdA1hiz7z&Ut`BfDlzA{Ao2 z9aObz9qAV_4&e!FxJ4pxg>ha=ty#PN%O3yQ?S_G^8A4B(dcw`v$|hg4h}+5K@g>NS z*^)WNo3GE%z!5$!7a_my{nb@mV_|P+;!H?UD3C=kcU~{kka<%t{^P6_QIqmeoPG6# zESB12wv;rm|8r%c*ZXB!L5&a?Zbj<}FS!gs=l%3qAk4JNHqDx#BKH;S(Yni*1?_^S zzP0r_oOVj@l-+Qxe|o3XEolS6PPc`EZSJ8YseTDzVe9tp8nYKDFw*k%^}EgdUg#?c ziv^#i_3f#%2Nxgj4{brw;BXn5^HUYRMhkLFe-}Bk*6etB9na>QtlRMFft5Tzu0bi z6BX2y;HOrSaxi#$eG^YtVuSOYnv<4v64vp`%zco&Kl{wUnn_p%i5s^ zG05c2FL8JH=q9-WwSq3(elPug`7U*G@tR{?gG#`GuY5PeQHGNcKVpGxnNhpcKcLkL>*H=q?|mCS zv@Y+HG^2m?s>P2s>&<(Oe%oV%K`tHcYmWZ<*%|0TUSV8g7WJc{Vfm=R<)xhO0ldAw z;!f*ccswyCr*V==f3~1On?zl$N5VZ5D&%$!QpY8fcH0ABBTKmkQ#yid%!ju@MRf3% zJ8pa)@o@ANR66;J?e5(7SeW*?k(FCr4;q*e7#@5KHN4hzbG}8tLfRV5TDBE1RcSf6WjIa` z->}so;S&1U%hqt#`UM$m1)THR-zo2>t1_m zw>}sByB^iry>Ev&$AAamJ0{yS$c3(7nwD{+hy@ln8DP)hdKYPH%qI*~1FA;8`z11VOH{PRT*_>DhaSHwEUJ2V)rDq1Se+k6q{4 zn*KDo`T~q&kjBBDJIRQY_n zV1x&F(A!JegmHTj?=$DoRvweVy@jFA?TJ!^iq+t->>}edWb)mY-(k$5gEF%vA!k;p zuk+w1Y98H7ScqCa!iV>Qn2p_=D#W}9c5`~3nBXTdz0Jg@JMXTBSF~i#a73h%*T!>} zEvbT!X+3#&$-7s*FZIJ_yY;5y^*duM8I2h^r3xzRxp7@+1XhF5Rhq9(h|0fzrTs}k z4NAJ&g>&qO41WFI#ijSR@+H5{c%MldOj0HK){f?(kY@5@{km+?`5?>JGURvIeuz-J zP{YUfTxv9Wz98F2*rsfKln@F-KhSpiD`uQ;sB426^d$v3HA)FgV~3v;M&H`oVWhS< zmnvWD+jZNOw5rPloo~Eo>VFR-(Z6zAaBC|wuhEK_w=tO!Qg^oGOeSun2j!JtHD{Y! z&6!mknjLvdUTV6q*N(37U1JfDIC)Q%vd>WMvu7P8_6f0#&F_pv^q$=WQNpv`qZ=T+ zY4I_2f@7Skjo!(eeCdUe&x)66^Eu+vrD&tW2g9BEbjS6vVsv@qJk&oyO7(`M@{J0V z+-l$X%ulsKSqx@eB1Sx+{2zF7s*uTfhWCYfIF+ukI9BneP+Lg%Rt>KV`fbnWr^*RW zf2uoM7j*J!&L>O$?esv3H~+yB8H9`_+lQBFE*wEG_ev%w-1hK$bx$PLx8}p9FMULO z{+3h40vFGK-i?urGLX|PL%jyEP%w4w0lT*2kwLJy4B2~Y4zopLg@Zpd1DQw4lZ7as z?5P%tT2zU2ysdF9U9ZTo=7?LMnnd@B(n;-Xy&TU|yq(?h!Iu_nFKd(Gag`uu-p}~* zo}acU4g4DJ>5+s05cdet9XAXBc@QZL!rJ~AMtEqAXF6|Dx|alwmm$5k=0uOUt>lv! zr$d)_Lh3=~(@-M5!8QVU5-s*b>AJzUJKog~Kk!^}-jXVW>3rnzn$D%;#XrfoRtS z?x;-P)EkN49wvTlFbr!Kf#Q!N5-%FQx$0d?jIFONpoTlfal5cY3q$QD>#?=SpPlX#xRs+zc3m{lnM4G=cDC%Nbj zyqaO6?RyaI03mp&=o1NIz#ubp^yWY-wW@K)a79Q$) zmCtR!!vee9amw+i5(aQXJO%H4;=*sCIZo ze+O~xG^t>a0%`!gJ24JzbbBdqVC56&7xw}6?S0h3@`0}#{Au#1y#_KHPwg(w@`7FD z>`(y7m|{?XU~q(zTJ;iu37U~fQSjh&?{Uy?6is>ZDX7=weuJDoLJj|-ZnFoZMx4P* zu@^zXNFV?{)>=@qQ0sFIRN2dj(;fyT%wl_V-lsq{@WD`_2(=P+NbK+6rv)qegG{?O zsLB5+BPo;t%=Du@Dm@u_5>@H|1h!k}dyX3L*))dfJor24Es!*89M$uYq@RTY%GYEFPM3+(9!2(2B(U7DoyK)AjickC zCYhjYXAmm=2R&jY_rX%W2Aa-+?_BTg12t)h84ppgvy7m-1k|0p33MbL1oGQ$(Arg( zzik;(AOrxx!RvGslz(lS4O0cmWaN-+V=d}{s0#`teWATmAl_vXOEkAkW1_NO&K4OMwO z@LLrqhi{WnK(WX`oygDf!LOacn#O}GSHL(r=+DEM2A0g4@Nd9F`ZjLObjIMLS~4hQ zp*#fC?1FmwWeQ3m@YTdwx?MW6+d6d|D$+T{<3&}0%4ItY^%NO`;r|uOzhe1UTmGLl zqGfE{Ajg(-Z`|;t5mY$G-+cMU-Y^}SZJ#ekpF#~|I>5T`dU&~>1y}&84;Q=Qn3(*^ zW}fd>A6lQ$6W-i>>ZWQsVRh}uYO>Jmu*LeRdmP*Gq?@7z)KpT z&Z>*VNIphbR{p)MheA+)*k5xoo(2-jWpg`HUi){;@TCM+5dXSFK@t9|79cfteI&;@ zVV@Apz$;6*Hp1n5W4g<6n(3S}{)V%ILrCA+ZI2P|P^hi0NnC!C^18o7| zs0n4uYaM&Dz3VebymtdZafPebI8jHOq(e+@+u5R&oyD~L){&!TSXS#yPmZPFDTR6? zn2W$|u=W%kezcy{Qk6)75pHL0^A_cV3xD29rqliDo3ECMacz2xI44u%WDU4_GlCX7k0$kT) zWMl=NYu|pxNWAe0(8MKS$iv(IML?!rzYsE55E?B${{fgx-s#&$1sbskd^=2gwc%LW z*22}OGd0Uk%;p%qrVHE4xJR_$^V(ua*=Vf{Jc*W3L|AoL1a zCuy~^^aQPD3OthsL-7VpaKjzFPXwj7!2hTu9)-$Yz9*H27Tsygb$%^o@X=HA_WB7K z#CKWFJ#l~)@xN27^+Hy(N?h16zJqtcThHm5Qh}=5B%ZS4P@z<$v>1_!luoATFtTnK zk+6t;72*Ja;SK+`)gL$FGjezArTjE7+B?hFWOT4%F^J>JbnrtaJk|p8UEwJB7?{+! z;1Awu3gBaaS^t;D^4Ez)Oe-|a3^#xVY_lBv$!joO0InZl0$;Rnn;d!ZQHf2 z3pyX~i06+^x!pSj06-jAHy-@iO}}yT5)Ub;*OpD4{4=Cd1(wxyUvL9l*`X-ruQdBP zU_dem_ARBo91_&I0|0yv-nMoXH26p(-e?X8-K0e8DnH-!TkeH2qx2QMiB3E`Y=v#n z2WYgK@e8kcetWvmY;$>F?3o5{3h6H4pm89w7%@Ba2cz70b?z$_X|w5C5cod&EcE6I z)I^MaJCON9#XKM@b2+LBHIP5;i0+| zS!2KlqO+XN=;o|eW zC=>ZE$vaK_htErh#>+!vB$el^{;&(z_KX#)(bd^8!VAEcu)O4em50ik#64djlOa=( zn5H6~3VpJ@u`wKWcB05lk!^T9V8@`^im3dSLH${_VX^gH%VZ8GIR=;A<>Uee`bE9Z zWzTQ6{_Yly3sCW;GoCXg_nJ@+wNr2R`14Rs_63C1JnUn0iWU`zB&;X|>``c)`8Q_1 zYosC-`!KYiK^h3O#Ljs}kIV!-wz~PS`~qzoj*9DW_kc&I;b)w~Aq%~R9aT*S9*Wjy zLO2(YVrSP&P%nQ?s_^-wZw{aSUTHu#%6aIpR80#%Wn7@s`XJ8SD9OL&kDYiN9UW zJuu?P?zK4)z>|&0!5sMG2sYxIp2|Ec7~a};_)O>zvtK6CulXHCg|pDbu9b>}wCfm(`(5V% zHfl@_6pr};BLrCfMeGp2a{a?C_E9-F8bChlHG5cMV4j7Jx7SjPUhq=fVf$j>+isuy za0z&&%n*uMh?_(nU8PPPxxFipoz{w$o+GqyE6r!w?~rERg=$1u02LR_S-fqC_0`9N zU@pEl)~Y_ad7Vl7_@skrFNH2Gy(7~sCZ^d)TUR~zJ!B7wM}b0JesS6PxB)js*2HBf zLnSddn8P^+0_aAL14WA; zCPv16Z=}ft0OEMAfkG`979uQ@!r|FzsC2Qx&T5AUtV5*tt<&%GZ=4G)>l0#FA-O$# z-Xw`ZGOgF__{@+9O@jQ0^Wh$e_aE6nlMlxOF@tYJo_3(*U=}eTA#>r~MFT!426srV z)>P@4)n2*P~kXR1X29NfbVJ9jlSKrtgOfIm;!_HP}1( zCJZP1mAJE{^MVjB$@AGW}1cjS6 z1Y+f90EV?RpntPr`o+&>INyP!*ZUPMrb^o3^lBN;ib0G93<%nVB4$0Ilkt}hD0J2MzP(%2 zGWYVhJ#G3g6$lsz$QWSPlHYb}x|qB8@|~rGh76=e%#R`Wsnh?D8ECci4ZA|1y~1vQr`jG}&lIV3)-nAd>ehw6yVRXM@x%>iQ*%_RgHp zsrjiO1E{uf$9^cM7fvJY!58o1HKq)`trgU0-jGE6vAI>?m@q~45`kw>NCu$L>tuG& zpaxj7I7Q_^KuOS`M1EY%NBGXIVf3Kmftxw=^JoPKfQ8J|)`D=@DYrfd6kt5dqVc{N z_ezByICegKNzHs+;M`X^tk5f8f@^QrTcQ4C?l+LIPC-o*{H`^(p(*@>SY>1=RiZY{ zw19Kvzd#27>FlKms%&yOa=`<^n2}XHxkh(k_mKE{1;UbJzt1|87fc|CSDX$xiHY%$ zoy=|je4j&9zgY>DPxF>fqC7kFcGS~(M?Dk%x!L-7wJd{NQPci6tw7y$nGX_qdFF68 zSEUlwOfL88VufP&L7Jik_01V_{TAK}P~L4*f7h)LV*DL|;^avrHA1Z+lwW|Hxua~h zeL>aDvo8Y`@i0>XUXO>+!BYA?)u5bz`<|8QmzSyPRDi@kgzlKYkUn838_KJHuwdf3=N(nf~Px zKqloylzM<=C4B-*)P*W4bUDD$(YH=L(qo;)eO8u~pYuKVaJ^9Y((c7`wvdj6fgY2c zt84&}K*ZM>+tQuzjC-mv1{gwpD8z#x#>l9!7`>nV*AW}bacE3w2N*daU^8U?Ev$2Z z5H)TM3%xR{;pI}7b=8GVjwwcbVjR-!v;-A^>1qnKzJR)nJyhK!tI+OO>d-h&KwoZ1 zPW&FGc}HPg3iCvVa~z)ox$97B3eTFLdmrs&9ErWK1oMx~){MP<==73|8nF@D+o^MVBR9&fl zZwa94)Dq7}b-I9pTXL%rtQ7*fay(?cE{3HUOhWDp#QVzr`qk+(M~k zs~jWgvu~SpDF5FiGoz#t!MXYVT>7VB8s@KtiBG(c>%{s;6MYs2K07;*326G=qVM5s zP#(zr&KbDeVVP~kZB7q> zU6J3-Uu*S8pJCh1i9jh0g)z`H|8O}z%6GYe8@;Bkgh?7z`90|;Gf)NdB`mG~_@20# zeH_r&bFqr)Uh~Ly=}>5C)W5@rf1Sav_@>=JTnRbaJ(hn{7*9SH466fIw8grEmbw-d zSclvZCr1Zf>Pt-e`pi8I$Alu>uCRI|>I?SUf0x$;6I+s(R5AB#pi^)bv+E zfq=|p1eTSLTN=bFdbXu|e|GIeLC)@M>fGBA_<6izXI$mCR6o(o)fX;(l1^_7Ty{8H z5@0mfq^KrA0}*ql`yo2`m*JlK*-&`Gv9Hn3UpU*qqt-l}Nq?^N0(&^UxD>h!I@uzj zDT6ZMF3e=#d375(b<1jm`uPk_Xz34jD;CIGP!Hah4h^iBM9&v#UVyKr0b93}9i&L% zBE=sHvUmEaXtjnChNE!%Vmdnb_EJdE0>pVvx)bNIA~mDw0r*P(YDy@f_LFHiod9o1 zy^*$*dSMKhfTHn#4-x<@3eV$&`k;&lC=cMMq9k)>yfgIZjUd-;#o!Gl?CUz@@u@@b ze@gEk@svcqX=S;j@c8*TLcvWpVs?7cCF%G*{%;^>fd`%j1wcAN@k4c1_%~RdSjzuy z6wyp8E+n)SytMm$=_{XCTi_%CI>!ke0n>RJE{$<8-IVUDGth1-CTc=_Ip&{yxpM{s zFImJFZ*+E?xVW_MR+g&JW99?YK#={%_t^ZK_C`URZ><`la>UMCgLYj;ZM|<>1L@Xbwnh z%+b6%mI_hlU!H<|SLY05SJcdD0wCse(T;WtL~AT(GFS&W+4rH=#@5W3GCW8BqK-Sq z(Az)e0Q>|y{O*eV_OAZexpn|r4#Zj9K+66X1$zYcCmC?Whl)!eD$?+F@+E{w=Fm>q z8ST2n_KTUx1SY>?;Q*Ca&1yf)=8Elo)#dNAoYDb2^gz#Ri z`k3VxL&t?yU7G&rbC1Ga+9#yruClO{*FDoiJ3-7@HYAhluParH0Ml_yI7&1-`Xj)H zvo#UKW=6Hj<=oW;1}vdN4`g|utL47m7Zi_s$h0ys{R|KZ5Si)RL3m((C4iEya%&(s zDdf8>L>WRMjo#lEFy=fIC$THwV}Km6tG(7}rCBphXdL(oXL#qto-)ot5+8`}=nX%0QuC z1*g6ZOJR>y2Gk8r-yRb1>-7OZKnK3JxY_}J=OjZ^k1aQ#iS^rBH>nfG(7&}W^MLBe zg({y7hbzcWgK%C!&vbsP`o56ER~a(0%RC<+{Z>4MmYL@PqnvX3CBye%8CMiTnnJ)| z`O$7&)@V0>nOW>{n@`_P5qhk%ZG&K8ph{8Eakbwqaxg9j_8!i`Bu^O?xaQsGlt<#)j%G_Uj0_Dcu%&ohvh=a+Czq^&Ee+Bv=szb^Q}pwFx)AkA699&*M4rE$uPL5bZy5YgZ-cL9&@!OO*r#SL;Xe zWh^P21aD{ZeDwT~x;Bs(Pt2U}pYxT}DiD11?3v^Mb}H%Wa4@(IS^w_#?)EAIAG}&z zDKQwhEC(V^0aq|kkwvx7V=S&NOZEvVYZ|vYKnP@0r_Kf!x?k3Qs4@MmrD$>F?ck>Q z#m)K3N+8o-M?Eu?c7;jE5zZBGN5y`xl4aU!LrF6BzDsW_zbo?@S{eq*jmWFZ5!?3H z*c5h~6Cx8fe?C%<4{ev$j(ZIs(~7HXNxrF1q%T>YRW-Pv4&7!7W=9R~=Gp6i2q^NQ z>oGvf*eLo}cm=oQ570*$d0KW^XX7r7ZJp?UE=&=1Fs=0GNf8<*_ws$0=;iD{QS6As zml;mbg?=N{es8E(<>q{df(;ol?G`S#LJjuP)LLBad(CCOjb2ZrvRa)7`s5C*$_`mr|zdB(M{%c}L#%#yMn5Ol9V&%OJ;BqHojR8`$W;oW(5Ig1UXH zQacm%@5=P{_qrTgZU_{>U$c{)yKS zBv*LPtEOdi$t-QcMq4VU?4h>%fE)9LX4upsuJ(_#i!YGK6t5Odl^haKZl%LVOp=~S z)LIW|{e^rdfOmKzNpk)Ia8pt&3)@s4DzZ_Ga0q}qNNzsPkcNsi_mjf33Z9Unsvh14 z8)HkwX%Xk5(-bk3bj^Y36;{wDGJL-{H@b>#ysufp?)hPOeU+TJeplR>X!BU#fE=kTBl1sl=JE*g#L zIQ(c}wQN-wBQ;xZutz;`S0#7zK~Z0ppc0SHH*N!zqwqwYOQqx)dwbkSYH9Y3SAuhh zMewBLbWSBBcuT>UN}tO@30FVIF~ zhaMr;snJr-v2Dsa8U@S6+bP}V*-jSmv2@=i4dTJAwn_m>Z;CF)Qtu*u4yYTgzWS%x z$13KE9gn|@L4Htwdo1uVnLT`&{cZObe1>Q1t@h*BMLfTU>A`er5lx;Qc@;^##Iz3U zIIsN|szv+3jlXkTW32EhaQ_Q00fQ&oPHorJJeTOcABn8yfag!)p=w+RDB8_N&mIuC zvWr~H0I{%&)NSC{X;W+o1nvOa={YA>J_gU#l0jUb)Yh2amIYMAs{c-TA!bzQ@mlC` zR}GWjT4e{u^x;Cih5UK%@ymR)w>(=SdqS2y^>6o)PNLsGso_B6;@A9g@mRrGa4|%Q zUFhNP36S%j<8y9S%EEzrJ|;7=2YRX`Ry;;f@ZjUmFTOA6RJr5lZGC&wsqb+obax1J zs0{*hSdaI!KmZ1-!KQK^eEL)=v9!o;ECU7Z1rG&%-;D=rTTbr&9ON57X{ji48pxA= z^#4EuKK*jyb|5Uh%S5S1Y4kLm-@sZX!>9k%)IP640>oXFH?p}BGpAh_D--Tyu8umOQ{ zLz|&c*L&cuZnp^>Ouhf%hj{_dHf-~fmCJqb*l_uFFjMWje}Og$Y-9-!aL>|bebWb!T)f4`6`Q1qpPRtT5>M6m6);m#uVY#z!5_Q(!>Uwk%T zjA@FBA8)t`K5#+%=1H)fG?2vr?zM?7NNH3- zN;A)O7@8Xgpe{S01#6GuKV3L#|Djp^Ogqqz;^Ux-J<}j$a&Qb+gn=$kg~3oP6kxCa z6~=#s@qeLa)ILM+sfX};$T*tW97D0akN8wSXsJA-%u1I{NCsJodrCm-E55oPfGH`I zFLqd{UztQ%;o3Vz8(%@41SMcHBw5pToeod(9ioAwZ%a%=It$jG*q{Gfd|~Wuk%s}7 zpo9KQ^mqlJX*NWCXa4{kW3$;GFatwXfW7E>r>r9ZS*k{0Y{FMB#Pfr${-+xHFa#C3 zq3gibvmyFKG`NLVV!*wB`79i=UZ48GXIt+B$mkcAf0S-^Ir;!pE&{hZN2s958@k=$ zz~Ww=5x8pthN$FLw*%G}OZTy&)$-Dt_aFw-GcjkIQi37y&W+*KOgo#x`CfyHvi%x( zs88^WeaYwk({;RC@V!6OU4g5!^d@wq-OaGHIC(e9_nU7%TW zXzQyeH4)K`eRR<~l&zo~65cAXPZp1Uyt@XlVGqhJZ=qr|{TKVfhg#>41U5%m^t(`A z0wbed?C2%!44Hm8yzSQXRjYp)OpGe{um{$7MHza|@{eigzp-j#Y@E&dT{V=D8$zVX zVAvcr(n3VP9rJWN73M^3HTX>cz*2z#QTMh*S3pTDl z5=dtwU>C2ze^FxIfVB!eQYtR|*Vsi2sz1e+*7DVZqqJh3@0Z;RB=C_*ZpB5Ni zN?^T>U(yA_f09&11B-|VMv6-w$HwwYN{3+HuT6;0YmP}5NJiH=%XWb93HE4YI@DoJ z=mTXyP%{^;c~2(SNj1^!^hIY0lwR zpvrL`Kp=n6r|&y~`W*T|VNTr&1fc)fwFQTXzWS2G$+e~lL&x8!4sj7sM~Dnd)eB4T zg64g1yGc7=eUn~g%I6CB=ibc;_Hk@?-LACi+t=7=J-eq-J{>t=?V!H>?L7eFl`v#O zxb{umhmSM9O)xY)!pMm-LAtrf4+$jgH-UujJX5oV>NNCXUqzf*7RtFpuYYs%$a)ktn^HBWVT1m9_S`$2oN%M8ua2zs z@fvQqu@JluH>+pf>m(x549hEFoFS{ z^UZ_{0wbY|YX!xnCHvBbLtF??BQKOt^yci@%00k0`g-ziqne5>RNpj&%3X~~zrgD6 zb?pPN`qM3_DEruCEQM$O3|ax7*K#ZU=_T^Y*Vj6g;wWz^IFv0fq*lCcqEFZnoIt*~ z>S@X>SZq5bJB(?zQ{V8Ivof3;eb|1XGI+lf7mIJ#XA#xREV`}tVx57&-PreJ)vhv$ z;In9XM`6ngIriSsXQ4oABLWgg$8T!+n;rX5{D}?dWyofsx_(EOd4bcg3;a--VeN9T zLH?AsobmpWntjgD9*1}lBgdsZ1TEK|N{igL-KLT|A2U|i)~VOoHn;Y-rCv<*?P--W zT@M78l~Rx&;WXePEbp}?*jzAeFX$s!BIJ`gUfFySa+v)9$&<2X{sqxjumNhu+xG|E z3JS!y7o=-~&)_iVqY$_t#Scb#eQy*4Jnd0XILtFRsOz1@&e0;$J73h_^vF}wxiMEg zeZVa1f$EkKlD~=_*f~(>{|FU3K5-D@tVrkWOBOlK2@yb|358OXk2XKv*#TH={h%(9 zp$z2$4KBcUax}!6AH?NSAnIm$j}V4NH~&#}*?6Y*y?x>;*H_y9;uKmaK7Yn{U@d>V ziwF=n>_23}5}?h6gAm(zaAn_Qp??&eJiWV5c?x&ce{!r-Raua0%Hk)IIzn zO)gwhKjl`y!$p(wIaebrWkOT`=^$@zv8H13i<4YI<6~^&G?em(@gE+&lRISKQ5Uy5skCyOkY)5XkMIjEV#Tvc$iS^5Z>6umxym6Fk}^-9nK4H4b8kr zw}y__B>;b+;_a(1NPAg9e_UeywD5H&td;2~{h`Bz zL4k*mRcdJP=~;i+@$RDFPfp!M0kJgTkIp+pi$d3(690o`(&4R_Azu||*~@-X?3E7BJa|Vu zQFJ}7yJ=4gTUIYiJr~$4wt{ycrBiID597SbuK2_Dx|!e~`%l9r6g_hS)(jwS53V?kD-DCVH&;N(7~t3JunC(_Bl`3BQ696VYRdNkhO zK05m4lai&*B4P{!n^W!|C&bJI6zMG6FGZq!9o&CL$$ym6HCXlgv+qF9 zuDJ6z%Z3yP`uRh6r9ce!{%3Fg2Op9|Z!poFw~3=+t!tQqrzF2(90AJug7=RneLwxf zbI^~HJsk8dA_!GBBv{q`>o00qGEe%z~r1H4KuhSrt!=mks4 z2LB8aRv#xgOatu|3AFuBdZEYi644S#k_Yrmi}fYql3hguDV`vtZe!r^e%$K=b!)qK z2ooznqYqTHPWBMy;hI|*-BF8E;EDA>_5XB#lEMxl#mxoVkAVyQ?hrCzmWWrBaFQOf zD@XpK4zj#px8oMcC-F|P2_42s_mFElp89h05N3pn00+(mONl?LA|r+_42XTF4R+f& zB@bjRZ-s|NWwSQy{IiJtVBu}=`M`yk1sm6r3*EdaMgR%+!hZ$(U%~$WQ)4GBFiU&W z9mo}U86&Z((|t$FN~6{S6L=4khU2NpyO1r(V)*yAL{nyCYX%*BiIw*8jAr*+Cj9`` zE#AJ#V&cmo%?`sa4Ua5eJu2u8S-offYW04*l7cGejfNU*U`mqx?Zkl>MJMRE+1aY> z7xFgE$GS~qiS`8MCnT!rmF!t+x{W@2T(l5>;BohBto^X%0kcd zNkt;eY|wdLFOK7Vgv|Ju>%X&VG5!ed=U=@a$pxjMcU>>oV2^HY4~5ia)kKq}V?+7Z zwZg7A-dhyjrd0LxmtgNUdn>r6xmkCC{>yT~}pYb|<3`4|ZzsV@w+b|JV(`yOfC zOkW3c;@roXlO%gPT^_VSt}OE^X%5Y1+gB+r`YaXXoqTFt;;vH?KPww^!4}Q8&tUbaizDnJV*4p zgV&B$=RJX$;x}VocWb1=bs7%mmvqglu)=4f?N%nAkh3I?%W&;UE$S#cw zTV77^kMsiVWj;4BM##kUaJe`mm*FJj1o*{sCxDj}3a6a(B|Nfy34 zebQX@g`#%9?VFtq){tX5LERm6+fI4`7cBd0oi9o)d!OcS>2tANnLKUS#qIDkOT)lW z!RZCFx5~gsco8K>qQk}&<>iUOBSvvXh?@5_M+JdLsZw+Kx9SyMi<(qVH-kOoh@NEk zN&awsBE3yB(=0kfbZ7n;6&SW^s{HneIC3&%JtbpAcPuO!kp+V{1jl?h$79*MB`368 z{ISnA)-t~F#>bPjZ>ekCDh=CO6C9pEj*4 zi)7s$4MXpp6IFEFjF>`p(X@@Bshm#aPoCQ}ZA68%H*x4eE12plfW77@j{O{k$+q7V z8ffy$+LR~|%?l_}@yyX3b`d)#yxs7~vH-v5DP8gY;78tPPpRJM(0|c82x4ykJIDUT z3qch9=K1CwxSm_V&b^J1@}+ksG}MuKOXJjUT%7i%n7go^H8c#BR_7k{mA2LK`ry1W z>d}f%RtHmihl|aSY?+ajIv>-OvSRxlH?;7oEu9Z>q>ni`h+F^mAX>4MUA*v$G|2|_ zJo2&^A2?L9ulun{#MnD#4&cg~OT|hZT3hm#;T$@4R&Qo_rI1xC2Jj99_nbAdwQHYg z?M84hl0+tPvW}AX7@T`TJls*x2Ri=e|EFKAte|Z@Ut%jT&vBaR8cB3+^ekpwWNE6T znqVSQiR!s#{ZN!`Y~NaFZu6b7x<_&%mxrQlkw$3`Rwl+9LUc2<8y0yLv7s~JaTEF7 z>*bwPr6WD)fF8>N?mQ#87ITA!WeIu}yT$~=lkD4Sm&f6f(8YcAZ!fN&BHFFWRjX~X zIAD_M;<$UvhZxVIitBT{W-xec!O9(z$t-1a!(B7}jY7jA&ozxL;`-3Sv{04L_j`r} z#^LuN^NeBocjj65BD_vyA%0T}6D5#h;z;7mP6k#aL-P;OnW7?IiCqTs5RG+0agpt- zNV_iX?41%Vt#|rHs-q%&QX_Q*Xfgw9k2ulYDv-&8 zrK}wP(vpOgjd<%{ulB?N_0yMACk(szY8@*f?(QjB=$v)xF#bm zzS=!ylGDftpYtDIR`Gy<8NHrk{l*mF|%~3btGvlZiZCu#@UKJ(Zyy?TZFsu6`r;_ z)J64}&D~3>`HExGs;jw-Vu%{qB4+H0@)^v-svPZrL1LUV8G|DU$OM5;3LvQJ+*)LZ zpW#xYFW8+6L4HZLuuE5tNF7Y#>otpQh;&2^;-%d4gB2gu0#SZS2&|x?^d8HJ{qW+A z@mcqYTW+G0BHyNWmr{(@B+9OL3@Ce(Hu{H(ltTK)n}Md!||T3`TIkz)6|w{S1=!BUe3zzI5Iok9xEsoGrjwA~~K;8)DtLHYA-f{v?74 zylwVwb{HhBDgJq)UfOG0TR64pTquJ+bsMr!dAex`PiaaaBiwNqd}o)^l$r`NQthb- z#YWR(UYnnUwwUbrDOg%aYkYc1`X)z_lQl_4>HK`;ZOk`kcuX@^H3OZM@##V{3VXD$ zv^>aVx!O)w4+0~#7o$W^R`;Ygzq+oEMNOETIO@JDf5W;qCwnu-$kl3m=-rX66URkj zYHCr79^D`ANNt(TS1me{TY=$sK@bloyO^_s%`a$|akNg%hfgX-W@jMDi2E`E>c z97DEom%HmVjTybeh?_=-PDZ;KBn6Mhhd|0j^95giM9xVDJou7UR4+M=HH6Z z)vV~wD5EbB-AxnoyCbUP{z|gd!+hJh|1RpfS&J_5{SDUw73QxW_33XUrmVISBpyq9 zV3;wkosb;#VbpSUe#odp4#ZU66Lnx?=NyMIp?l{CE>mcokFvI?dQ~q?AK{df@1xJ1 z=KA)$oucVicsch8i52a3*t6j~E)mL@QA@lL--XYM|ieI%h7w zaN2O6`QA!}33+p1R&IF4hY#j&`xM(JKPa9=CbR^fH|Nau7gyIKtI2ICR(bVUY5a3m z&hBR;Vf`S^q!#rZROn7b}xiy(uLq3zaK2Sa|>0&#Tr z)U^N(L*nPker08N#!NY!SRe8X3vw{QZ5;(0pY<$+IYc;oHMTh8AyE0eZ&ebzM(sPBm@aw0!AHyg=Ze4c4Xk$J$7TqN!DShyZsot%F{8)ic?G}qqk>Y~z&;I_ngB^RP# zJ{wh1KGze?$eJu8P>pH+F=p@J3a9cI;pfdhTw^K8$>{V#NCm_4#hGdYF4jBj_ElqJ z(&!>Mu#{>v{$kz<6TjJma;qygc>^Wp*EQQ z){*RvWZ+aRK|$_^?>kI@?u`DlQfeO!8-cI*H*y9ZJPJq~zLu&In&tM)RjzIV`7#7^6gQ*w_r9x*K_nSnh?t10&z31B~xT z+sQsyGRbN(9%h8+HlM=83NuW*z9d9e4JkutRw>`V_zj;iulFGqs#51M&qY3g_GPK; zW$hyCzNpcqa>2PeRN;|(5j5fF9hH}YVLrxW`WED}Ip5l{k7v%m294j}IIC-&nS2*m zBEx-5c%3wQ(s&B@Sq9vcbIL3mn{+*eCRjOeu?6BirHW1;q#cn9M9-U!>#ucqA@o+7 z5;LwL;JqCpwPjU;Mfk{soaH!MLW|>2+z|hFaYJ}LXUO4?H^s%*(cVe_(!N?d}MfkLI~K`w%u zF1-f$X72&pp^tMg-l$SV*^L#snf$JcjRGgrOx3zXZA|5()EtVaMBB7|OYrTiQ-}vUG8}Tg0v+x8KS8Cz z1c%S2_*q4AC@Z%PUTIIxk9WKpWz;&Gn378slpu^u;E>`WC7e7+!Abui1?MEiv#CdE zPb3z$8x&L<@;s9fFt7Ia)JeEli%L4X^Z0VAP*?P3jT)#EB^)|juT!1w~#Eb!gTf-8u%H|DK4!n@LuWO(`bt3Ruk@= zd?GZIgffyrk0u#~Sd-}0boGMz&^4-Q7cmUvjK=;cfv$HIi=vxe7WC`z^1s@3y(X*s z*<7Pm=6ixVMOSOWCZ#O3aj`zpGW&=diz9T?XpJUUoiqm=L(o$iduR%k>_dGah*7)|Q^Zu~8;eLxCAFIgZ zIxBXuSTraDeZ0$-j1QdtnV4^cP`M%!90n5ai5Nc?g%_ikEQ6DNRUsS7jubC1dWm>& z>eekHEafiM19ct%%kwZjh00cJUI5K!%kJ40M{DXRjnVqS9?O|y{5%ufYdBmdGaXSI$l|cc z2yC=?lo?bHdJQ7`QhY3BNsq-XEi6j*Zi+AvLvJN!iVgf0!^I#_wTM1V$pWey7kOjx zq20uS-FPuZl&ImPt=u2aAI3Z!TbY@7i+HejnI2Bz2EIUtN2v|&Pp!8f2^0b4f_g?P z&&LY7J7|ixh3!Y^;nFlFskIF7hr{fN+{lD;Xqz=KJmH33YX}zzRMc=(XF=)eP@l`y zvLlePO=g}1J}=Jd@oR+M*(sJ$=vVyYff$@56hgtX0~%hRf$B_Nj9`FWELM8~L<#TI zQFdZz1%cfykjYaN6GNK;TlElo6$U&w4^#XRsLsU+ww3FRUL^2`y<)n&dI*bpSQDX- z#G(`mz4W{hm)s4aNQVP{tAn-@;Bxo`D+?i>#chjDq^zXr6VqR8Lm0{ z!+NL`F$6eiR9j}W&=1vwgiNRq3aU)Zytf3Q^yUksr1g`>olapMULSc#3QAd)#VITx zUwD-kT8%|LFC6ay!Ov2>l-l|cp4hd9Oi!0;9dF^@kK4w5BsKiHjyZ-Y9 z0IwwZpI0+6Og=tTIPNi@85QE`6cH}cG%e$t(FfUwW{|#@1UlKX!j1t;xo{d>m4DJZ z`9u^%OPFmc>DwG9Tb@*Y+eLt8udDEaMl3D1s?V=no zo}x$CfkabfVoo5nHtl4zs);8oEnd48n$L0W z#d%CCNuBqDtKisM+@|QDqOu0Kr6nc=NO2Rsck)!o1oDUaaz0qpnyIni@qbn{u)zQM zyGieYXxTmkZn}0G5sCpW$dIts4$KDWL}J`OA0N=+|J^H)&r^@;NQ2jZ_9WNZ+u7tX zAHO#>J&in$1C#0lN8bv`jyHrc&19C9w()A|x*;C;G5)hy|8##nWCLNK*DM-puyE#} z61KHav5;{0K7UR6a`ECVfyGZp@G$}NjS}ei(Z_cUaMBD)=5rL70GO=cGKBaM-Kc~J zC#MwAx{gFpl+vQ695yDMw(0iT^Qh5WbNyB+uyZ%F|7>hG8urvs*1?(gjPDXdntclUPOJ=_hwG%i zTw9_h*Z>OY2Tmb598{jhM&59tX%r3a)ZuCaXPQj8hlEC6GsUZwIaU{iwK5^h`4})8SFP0%6@{SGOpdX;Erugy>9KABxa@(DRhTVh4 zO6kL`D*FrsCcAHa#7zcKlwHCCR_n?5(>)&0dLZh9Gam4!$6r2ue_7}OIv2F2oX3!r zc=dRNM+w{2-RB=rsN+8hF1;dH$|-Q|Z^zs^*c>Es`9t7J9+ro-YoTz`aCk_eG4U<5 z!X%e-VhSVtYQkwXVx5DO_1C&#VBHz=qf|06(#Qnsg+U&(kvkah#$bJ}mz#qPVZ2&k z3waAw17|+hQSLS6z5+1>XDIJU7euX_ugA~FynG4_&zgA7<}nAY+P;O_Y|6Age%*4W z$5Lnxr4M2vo3(Ovjr>Ch)Lq+W#2=EoEogjN> z7qIB}>+ScE=T?h2p_;ZP%mkOkXy)^appUFiW85RH8%%q&a~zSqlyddyJ#S`c^N?e! zRWPk3jBkkcsx%OHzjHndZgNfx6kW=$<<^7x46}7S#`-m)mgW-8&mA77JZfJ{2~%uf zN5M=S#1jVuGT+wdNSLtI-0NmnYDCVcoxuup31_;%;}`>QUehDI6kU2q=1PSlo@Bzw zoh$r0O&6%H=?5WR48KBtbtcfJij9jN%csZx#_*df`c5W;_V<+G>uq#;M^VzwQeT|r zv;AvhBSVe;GYbya@>gVgrCkwvV{`TURLde|d!KzatvMJWQ{hb`q*$T+a`t8(c%=iJ z{z<2r)9$V6c{f?0`44ZQ*$5$`&e7k9I{2p+8Q`<2sCZ5uMI}T2c&B%6S*>HW0w%TJ zL>*$^*6~HGhAh7wd=Si`++nosBOt#9vomJVu)cs08q%7tphVU#sHOH*Yit*8XNe<6 zqg78fxp8ZRXNt=Vp`!aYo~vwM3Use4o4;|Lj|S#z7~R~ccBcJ8l=8~WoN(5UdAiLH zsIeiUB`Jnk58z+=bHhXooDo|!AtlD#`j%sLeQ9TEZ!e<*_%el) z{!Eg`_>YD0K1tnq1q`al&mfus+`{h~!uIkdDdxGBTfG$>?0E%a0wuV=YcqA$Bvj|` zQZgy1Di=RIMmou)9VL0fY2nKJ3-U5~I=tUTsdI18QocbM@I-WA1YlFAu$6Uvdy)(~w z8VfUPc)?1SYxLoHVozr(o@!VyPFf+gTNZ>YtnHbiyA>aA; zZvtr>^Uqw{dZLJVzO54H(XP^A2;clnL~&OA(I;4-H2P`)yLgxQ{{F?qmQ|1X77Ztb z>ts;65=a207fS;!M~oa;KW+Tl(>c6eP&R?VIfTvU>zew~ndpWbt=FDg#upjT_ji&dv4G7vKjx1e*pKna+ zd0#NRb*-LKS%;~%B0uKs=3{>)HW-}jH02OrY26r2XQGE^li_TBm>c?-K%h{qll=i~ z?{An8X9rZjeod#gZ9mtK((ZGaZ4?>)xTFdsHB+S6{Bcf9yCJ^g%K5?McoeMVd|UuN zCYGVfUSAB2tK@#fg+<-l=v}%38Dy%Xf2<{MPOT zX$MpkAoQ%j+!p_p<$OS%Jls}Cxm}3j@bE5V>$+sBERVka$kP2!LHj|Tga6Z5u0Yqk zqV*HPbtflc0&B8hmvO8L>&GOreOln_qljZeZ)nH$;dUzv(~qsSz6DcT*Y>FO-8{CY z_NYQx>pZAF_y|2Iz-RXKyYAU$XczcOz3*!8mIm5cHFpJ!ZS(8hw8~?Ld!;@oLOdz4 z&AS(Q{lR#1G#9TGuLEexiafTvb>z6YOCBZea8IC+(?0*)>bWT6R45u z=a^sHzhOOm+{WmeM`VNjDDIJv8ZS;-uvbq@s&{QFwYK0Il36W^bJZQ8N2IJwcMF|M z$?tl=Dn0P?`KK3|;j>2@>)voV>XY~A2Rh~nP~Tko6!>N=3{#H_B~r%6`<&i%eZ8x> zzOtF)WG|J#w}ogb#%-UTiyIh!mNp$uKeYh?xyyo%te%788c~~dQ4zK9weDgT1@GiS zTyO~PGnbRX_~obmQ~C7DI|ZfLKA{3(Tuk;&h0|1F*X+#rG0)*cp^y2@)C%^roFmN# zmh1FRY~Z!_`+Ov-j~i4n%Bah=Kc13ZGyS^M12tNcY9?;S|; zy2teT?h@i@yv*RRv&%ka~xKv?Lx^Ya)hrAe5cr$>6`b2q;;tsu&t zK^N5>Cb;nBey~zjYmQEHcs#cQ=)o3P0<#ekmA4CRN2_zqKg4S9u6CW>E=hiE=1mO> z!O4u^PmO9!T@mJzun{kYm(D`3MA1lcD^Z0+ksjal?ryN)6e&4_zL5Bs&8HGv#fwU3BXCekde2rd#C_KiZSZ@X6CP( z^8D%RMlJ84aW}a&=&&g7Lw$?akqIaLsPKN8xfc+j^-#&3+lmj+Zp{K&4||P zw|sOsoOyt)V)wlYgIAc`?x*4yj^+=9<-eQL=tzCMGz&⁡Q=EVWLisA9&nRu$=(| zX~53i#|O_mCvO5_=&#N|a0}E7%AUHjZol^=6_ukj9$7R}Ehla%(I1B)vT`1=&% ze0F76)a;40{@aimgpW1}Ed+fVt$8CX%4};V&xCoNo!>><*WbuanWI(nnsi%B%VT%h zaeMA;x~5((lDT(qO%I;hZFJMhTleH}T0-SDfQvnM67q*_9g+N5UkjjWfzDp+0Wc)u z{*8lN05Yj$U;mPVWCjTminqaFzW+?n+1eAk0LlAg&=!YC?H8A#%*irCLExZxlL zW}tyG3S9t;C1_c^5`B%hbipizVDc*F;q;XoawM1$fkV7Eeny!8uVGq4;=2FBwEu-^ zf%^YKOa27&{3kFiga{VEIRijZy`SgENm66MXQ8u21MAb$H_^Ix6U`08&=$s4Ht@HI z+6Z0ZU64Az>>QT_bE37so`4eqi^8BZMFFniOL9<7BTuZLW&ofM0(008hI&&x>;|0s zAU>24{9YOVb#v=3!aL_nN-yw?#Ky|!m8dDAb7^c+j@>~Au{YfdyouJV2x+dwqU^7q z=^;iY-2J#Ra2X`+R{q)-AR9UW4u7ST@+X+|e^FTi6uHG0K#{v8;y$Z_J7UT$sP_>Q zGC?>p;;qX+L)rgD|Nrkq|NmEuOf~m+ zJK7&5+yc(2RpQ@yn;QV`i3jnJ)M3~}A{G_toZm`=OrX@Cb8t7hIXBj*TLl6eib~W$ zw^t$fccFajNzm$s0uU9fyfa!7o!fM|fYl#{dC!r~gp*d+ohtMJmd41fQ6BP~jNsqJ zWWf6h0bsKCrhwxKsG9ocrmcVVn*TxFU6JMQNwDEj%ZO?R!vOnnjX>Nu>t%nts8LgK zoY!*N?c1Oc@(8d*$6ip;fja2X&>HZ9zgpMlSq`8AIVQ6(2M96%+GmvLfy&oCb<7Ey zex*HvbSMpLQOR*m2i2ysbvFg<(ldW8`;g~pY?K^r1@Ox$M9i?TGnXSm1a;0fz3V)N zY1Q+5z;#scpdr#2s>T~IftulOCB~RkWR?9VA=}v}JeptRH^0&?)UPC=1B9Y5=u01> zgl%ifiy|)yndIAjnQ!JAVCfP;s`B9Qp3WB5vW_hd$f7zHG6Q&ZE?klc=Hhr3v}V=w zgX=2K^ONhEeboljAS84 z>nU@@mpRkz8l7jJ(V{oLogm!a&lfG2G^6@hj=cwi2e{p&j)n8zNkcMcdQw~xOD$vuuW7i8Hr;Cw z9eRqBR&{cnA^_2o1jLivg_zuGfE@X=u$6l26=~ys-if|vCR0z%-vuYWK}`Lb<*Fm* zCVKG~%QbhU`kG@h;JMZ@X80pgcmgMf4xPpD>}T{WEvA(PI}~iIHaX2KP0i0eeuJB? zVX6OvwihIq1htsrKs+A+Yf^akr+T#4L8UxT`ctt7U`__lb?yhv7Jqa5Dr($j&j8=5 z^PP|62(%Iw*l_BhHfkZtyfN991}mJqXe{F@pkxU#DHiY7++B!raA7-BDhzjA$0g+D zFgF}OXJ;&P6=vt^wJX0h*>8s{J!?YLWZR?rN%FJdk&n;DYg;qdUmAOq_3GGS&kr=M ze(1Y`I2u5;JUC$z@U4%IeOzEU&!UPEGd`R%wkf{PN=@3$c%8ShU858Lq@f7zne73R zP9r|;evkR<;5;91(Dz>{8@gdty3;p#DGzg1P=i&lw6VCQ6BQrbXE?`xGDn)-Rz#Lp zB7eWIVWgU`g7HT4WRcdklx;!7rtswjvqC4{d>gfD;}Vm_LF#JXh7w6y+c6S>5%MId z4Kd~I$j1EXf@9YLVT!4t``6SYeM&_dGuQL;#L>AYuN}D=qXoQraEq4XDX4I?M&bvm z%x4BbXsuhZvT*BbouqG&^`P=ogO8228k!S6YpC+wd9$S@ z_KW4)$mV7KV|bO0hoXvbPkQXXUVs0(uv+=ZCGR{;7-kK zM_kB_oW;{r?bfbh@uB8=+9mc8-aXUHOMu43v%kQFbvGdOw=Cj25ihr5JA4BteO}QF zS+&G#9=1IGD0|Cw-Gx(S9L9b>7BDwciI)w-QD+;s^n!gFdEOeWh$;2?3f#t)+VN2@ z&xc*uU)!qWS$-8kKE!E0iLST1;Am}_Va2gM*vDsPZi;#&TZR=Zd=PoSA_nzj0)qQf zrxro>F~W7|cZ*zJhq_O2qD_s1GC25?i&k^(G^wIx#E|5%c@8GNLAiiIIb59wvZ~HY zJ56NPkk3CXr83-B1GDlxla#_`hq0P9rNt)R=E^E%W;i8svc z6Ny=IPAS)?P)@@m>c*hrZl3?ID&>lndJO{g%Yk#v_Su}1J);yUT>dn zy^wStjbxTx3l)o?DLr2*JR!%m!}sk5)wau3W%c1~ab};FrUPe!9-}`SN$Op^KN89p zUXb?Tc)f>&a`s5|2dX`z^4m*UC#^G-wa8fw&Z`(*u=X>~nbT-=d~(%xg5yo$+@r?9 zl#5gCO+z6fHMh_k@y*?*l1U)0S^BTu^zwOUD`>S8H^06;+mzOKrro5`QZwX!>BUOE z{yx3yV9{gJ{ixDL5mSD_HT zSM;pE)`qTNF>Y#9T)I%&my$(*Um%0NI&zlv=tyX_QJ>j>w$`#fOJYTTB_sAm&bk!= z{zU)1b0SX$RZ3#>=h$Uua{=o%E6Z|XleR|cGEj&3`VXAgDK);UE}$j`b*@zU{=s?@ zW)_PHB<$a4zqGWK4d&G}JM!4aHP>Q&)&qf`6OucO0sRhStT#qPuzUeTXacWrQ?Fpy zxa&;3N*QHjN9*h9)8K?;T+&?`CgD(UPE4s9Vh!0{h*S`b!! zyYpG^dDBhqzOIHR{Np-n&EtztmOJY5V+nBQ4Xe;Oynfq8bD)*_6~^QJpw12kGs{4ewSK?&-f0dkh@s z4)dR7S-ia`RGx}bEQ4AnnC)p_VI$VYM`-#$CsB$Gl=q=}M zEPd(^YaARc%<-ytnAPF~*MlEIt*-MR6VhvWxCkeq&?w|*XaqEb9Y9u~gG-T)Ugs>lV60BuIMflL1jxpv>#|p5Gigion~iu()1LLVWISxF zCan*j?7&KuaRoGWWi|J$Gt8?H5tkKO^VnC*7Sb}gf|{U`weX98LYZw^kD(${5%*=# z)}lFPSFhC=<)BZn6uyNByO_mU+3XL~y6Z0>CMF6Q8Sful!jv7r5P>j= zW6W7gS;QC@cFGGq!uAv$O4>47AgTbl`2j@~m7E%k)qX3{U~e;pSm;o06X~^N3)3urOoH#YrV=`9Ulk( z=qe#Qb4x0+u~KhHqp~LbSoT6T)xsZK-(}GEI5B8>&&V#IUNI$ghgJfuz^pe;bhr#C z;NHfMK0G<1Q^{n`sdXrJEkqUOBcWR^BHK_y>LWLs5$EullmmaD^ONevCBqye?I&Av zpLO2&!4@)AIVGJ|VcB0-yKr#+einyqD@(@kr;S<>a*`f5ztPA*-c;}hTNJ#N>@m=x zl!v_?VFrW9s>}%M(Oj9A4Im zM}Lu|oC>y9e?fP90(1fqsP5bzuxzmL?$M_cmr7Ke`L=rjS!IpqP3o)<9=A~s+Lyjh zo&ckGQQ${YgZRt=88Y`Y=!M+d`r?ZG^0~xdY+k^a4zqZ+%|*01Vmljj;!5o8ttaF~ zM_ZSVEn&DWoZ1mErfwvajb`jpHBaf^{bcyWx}^MOk@b+urBN%oO9$=B+l`mX*Y6dT zHH2T%d1sw=(ueijKizk)ZF3p5vZ( z`#F*4@rJqI5}hrG-uUQlrz5MdfgjQdi7AyIA$%kv?+JWM{-`szK5MRcg2r|O7aq#( z;INI;aKWfpOy%Vrd66-R+3Io%BKH-FWr5pf_6S}7P$20Wb&X16t8ir`x_!6DRG$ys z8{$&&gX}|7bO(Lig*_XP?f_7JC$)YqL`~nEo7>X&D+gb zy9ONFK~c*XW?VtUH=soF$b5q+bx?)Hk$cI#S2)d6wOw{7d*Cel>^#K&ruw6js%_x zc`Q6u2q%49N&K4lZ><003r8R>gws(_e*S=k77~3e@HDvf_&OKy z9}gx&eIL#8%Q`hs;+^^;LC+2!Y2wMcrUU2-#`UuOd~t=8mu)f;+{;-ypYQFAlxJJi zj)=vEVx0cc01k*oQ(VQ!nsj&1L6>+=>uG1PBJg9q_jjU4f6SyfFt%Vrs8($>*bpC{ zSpmMdNp}`ctkeEp_P4|c0{$g86Ed%S-@wxk{K*Cd$wdm{Re<{YRkF^p+Rc~v+S_J+ zD;I#}8#5{It+kWqx$49%+f3KijBP)XJCmb@HMH@HFu8D>S9=3f=xUo+Lf3x8xfrO? z5@T2?9ZuR-8zWktAJgyO3g|gdM`eWpR9*W-=X@RWJT(jXNilz-(BY#Buc6Z!VafOw zXyVMw+~W0d6(}+*U@;s1$T3p8&yTlNrw06d}jsFC}%MLyR+6 zqg=#v4ooT$FsTkknh=N%-X};2K)n<(?7h=?r{p2W#2UW%xOu!EG!a5GcqpBS{^SXN zKO8Kjk*7TA(K^r-H^K$~Cxro3OU~2au>U+u>_?>cSULg&cx%1j2y&Tk37=wPg<=lG zkHqFj%ptAE3nvs4qXvJ>zmrVx*AZra1YHW3jK+^2Ce#IEsak>wSH)r?Ay1ST@c zxY_=76579Lw0J#gCbQ^iV&n^j0=*yZNq&1zjm}n{WOqq~XRf+WYvc6;%eje0hbJuF zbWBO_Y@%j?f1@hczH4K-*YC~-!im}e(uQEmF6w>?>)d>Gp|^meVH&YA*T;;uG^h`> zrH=u)Yn2-4V*fs7?XPI;AbN4Ti|X{#{<08jBISSTx)|EQ@3#s#R8rl|lU^Fd{*M<*(^h%@=Yf3`P3 zHyfubK-Cl=JUjEAS9ux3xpjK}W1B33^IAv?rPewn@|o`=vg$_VyQq|v!*1Nreb=4!!2qcvD`kq?nEsaFvv z8`{{&<{|0P+=D3>pcmoqBSAoLip7fu?XA{Deq(gjLJyRxL1(Pgh=g(}I3Ui01A^dQ zXk6aE91wsJM8*w1W2utBs|9$fBx$^f4ioM`nEX)RU(?c`#wzGOy4(DLY&i0x{X8*& zZsmR9Q^_3A)tLMz!{wY^OPQlT77pX`|Kp}6oa6@;Ia?v)A6wtP`q-E9%uj0l~|c=kazwb_OV$6yCWnOaXB1zYKT;mgfi0 z^*h}$8}dOiWWNsc4<6|sd`S5-&_<49~(pgSV<+a3AS2gLq7!3S8E zA3V$NPhca4j_KlW$Mok9XhFTtAs2s%`R5ORC%bYSXw#1=hCrMC;PC#j$RNO7uTl&6 z&Fuc^1J3^++1A)oT*l$P` zIGryNUk9w=b%g_xb!p2JRxqO#V-4Jp?~SKx7YhEL_OARN>V1v7Qe7poOv}j9$&#q! znk3txMV7)jrAbCgn92+>HI8M7sUvO($FT?LJ@eyyzpu~h^L(HAKF|C8e4gh$&hE9gl#-PQ(z2I=Kx(g9=QEOUQ~zPo zz~b{|a$Czf8!AzBj7r#cdZ701wYA<|nAy~bkD}H1Uxq8+4Kyh<%87-aM9f~i`&DwN z^l4F$3Tj}OM~z_7i?G&BjGyIDq0qoxbF54O+O4?%2tyRrQT#rX)ce_keZl}Cucrhd!6;vymLc$&Ob#1zj>QO0}3yGY$Ae<@DALybwtxYdr z_y(9x`FW-_jfGu-`0>e)4e}-H5+V9lmywtK1IklOy^=OO&T0}r{GGmN!=HPQQCA=u zPhBgXE$F2E*gDQsW*PKYzU5qw87rfW-b9g&mL8rI0A|uTW`tODpp2_G=x|NCG^`4B zV__7}>7d}TaU2cei3AJpeVXfUxO4=K;`(2Q;47<@KiCjE+J8VXK-{?9h8SiREGDcAGto$)}uM3y8~#s#h<>u zMpfVFMa_(ow`8`lJ0VXQzc9)tjhZ#N!OAVswSUMNi?M=xzX7;L>O$qWpX1lqcy_xC#l zVuv?FThCKGM}Sc?{+XnryU>2jbRnOk?Nr!8%i61`)b2JX)H@ov9Z=e&y^p=J;+0L0 zBg2Kg71Du9r~1xBZlI@bM?2=4a>t8816=8($c7JugcetAP<(fLT37=_E70!xmH-d& zhgandq0#qa{hqQnMz4~@b(c^m9Vjl9MQ(sb7imdEciE>oOv?#SC!XN9kuP=(-cHn2 zk39CCi4UC^dp-VaBNmiCZ|v51RVZhu@79ejak)`lI$vg0mDQ)DYjERYMrf@`f82Ri zN(xh~;e~p*DWk=1qO^XifgOR5J50x)N*?3_O%qjJ`!}Wa+U0!CZEDc8iP&)uA8SXa zw>%*yx2GFPBKvcwOG$U)hv+XAY)MGd9JbXIZ8a~yUs-*1!Y04P^crYdUhKmZUfskP z^)If1`(1atqexd`TPVZe6$GvxcS=$&9PwQI$;smV-$N$EVI)_b6{{)A3-#g-b>yho z>LWpw6s|1ZJ|!l@aXRND)uRquv?gHEq1B0J4iE^Gf|`NgZ>r#NLc-PW_hTQy#>L2L zz(u)>jC+DR7XFS$UOq|yug7s*x`GJQz&77@RYIb)yO0L(`gKTO)g(p(r%T2=(l?8L z{+$OY`x=ULFN37RP@;d_K87~fA6C&^X=Rz$XTyF_=x2@_-!rom;wM2AtVHPFq=W7y57;hd8$0bP_12(WDMG+J7*(>zGd=pCYwKf z)-P7I-GNk^eV@EP86Ddg_$Eyr{AXPQY})}jmJO*b)u*moo)Sf1?5tKF@$+|FGwU;` z{G>K43>;0+%0KNvLR|4$A zaSi@)oru8XEkWE!g0eL0*_q6R*;grkU;y#NV$sp2-_DBPp6-Zk^e6bVf6e$3 zko_|V$arDCgcbjSm=b&~>a6BPOsx+Y9Sc@Dw(onGui5-iW>o>-Ig8jk;u(AA4@W$~Dj9(Bs=Hwn%}G{XwS#lzpe8{|m*>+*tqs literal 0 HcmV?d00001 diff --git a/docs/examples/fp8_primer.ipynb b/docs/examples/fp8_primer.ipynb index 788d6c37a..a8ebd770c 100644 --- a/docs/examples/fp8_primer.ipynb +++ b/docs/examples/fp8_primer.ipynb @@ -5,9 +5,9 @@ "id": "7b3e6954", "metadata": {}, "source": [ - "# Using FP8 with Transformer Engine\n", + "# Using FP8 and FP4 with Transformer Engine\n", "\n", - "H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. In this example we will introduce the FP8 datatype and show how to use it with Transformer Engine.\n", + "H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. Blackwell added support for NVFP4 and MXFP8 datatypes. In this example we will introduce these low precision datatypes and show how to use them with Transformer Engine.\n", "\n", "## Introduction to FP8\n", "\n", @@ -100,19 +100,66 @@ "" ] }, + { + "cell_type": "markdown", + "id": "fd7b4f37-50a2-4d41-9067-cf0c471cb2d7", + "metadata": {}, + "source": [ + "## Beyond FP8 - training with NVFP4\n", + "\n", + "In addition to MXFP8, NVIDIA Blackwell introduced support for an even smaller, 4-bit format called NVFP4. The values are represented there in E2M1 format, able to represent values of magnitude up to +/-6.\n", + "\n", + "
\n", + "\n", + "
Figure 8: FP4 E2M1 format can represent values between +/-6.
\n", + "
\n", + "\n", + "### NVFP4 Format\n", + "\n", + "NVFP4 format is similar to MXFP8 - it also uses granular scaling to preserve the dynamic range. The differences are:\n", + "\n", + " - Granularity of the scaling factors: in NVFP4 format a single scaling factor is used per block of 16 elements, whereas MXFP8 uses 1 scaling factor per block of 32 elements\n", + " - Datatype of the scaling factors: NVFP4 uses FP8 E4M3 as the scaling factor per block, whereas MXFP8 uses E8M0 as the scaling factor datatype. Choice of E4M3 for the scaling factor enables preservation of more information about mantissa, but does not enable the full dynamic range of FP32. Therefore, NVFP4 uses an additional single per-tensor FP32 scaling factor to avoid overflows.\n", + "\n", + "In the NVFP4 training recipe for weight tensors we use a different variant of the NVFP4 quantization, where a single scaling factor is shared by a 2D block of 16x16 elements. This is similar to the weight quantization scheme employed in [DeepSeek-v3 training](https://arxiv.org/abs/2412.19437v1), but with a much finer granularity.\n", + "\n", + "### NVFP4 training recipe\n", + "\n", + "The NVFP4 training recipe implemented in Transformer Engine is described in [Pretraining Large Language Models with NVFP4](https://arxiv.org/abs/2509.25149v1) paper. The main elements of the recipe are:\n", + "\n", + " - Stochastic Rounding. When quantizing gradients to NVFP4, we use stochastic rounding to avoid the bias introduced by quantization. With stochastic rounding values are rounded probabilistically to one of their two nearest representable numbers, with probabilities inversely\n", + "proportional to their distances.\n", + " - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n", + " - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n", + "disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n", + " - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `fp8_autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n", + "\n", + "The full linear layer utilizing NVFP4 is presented in Figure 9.\n", + "\n", + "
\n", + "\n", + "
Figure 9: Linear layer utilizing NVFP4
\n", + "
" + ] + }, { "cell_type": "markdown", "id": "cf5e0b0d", "metadata": {}, "source": [ - "## Using FP8 with Transformer Engine\n", + "## Using FP8 and FP4 with Transformer Engine\n", "\n", - "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n", + "Transformer Engine library provides tools enabling easy to use training with FP8 and FP4 datatypes using different strategies.\n", "\n", "### FP8 recipe\n", "\n", - "The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", - "Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training." + "Transformer Engine defines a range of different low precision recipes to choose from in the `transformer_engine.common.recipe` module.\n", + "\n", + " - The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", + " - [Float8CurrentScaling](../api/common.rst#transformer_engine.common.recipe.Float8CurrentScaling) recipe enables current per-tensor scaling with FP8.\n", + " - [Float8BlockScaling](../api/common.rst#transformer_engine.common.recipe.Float8BlockScaling) recipe enables block scaling with FP8 as described in [DeepSeek-v3 paper](https://arxiv.org/abs/2412.19437v1).\n", + " - [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) recipe enables MXFP8 training.\n", + " - [NVFP4BlockScaling](../api/common.rst#transformer_engine.common.recipe.NVFP4BlockScaling) recipe enables NVFP4 training." ] }, { @@ -122,12 +169,13 @@ "metadata": {}, "outputs": [], "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n", + "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling\n", "\n", "fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "mxfp8_format = Format.E4M3 # E4M3 used everywhere\n", - "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)" + "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)\n", + "nvfp4_recipe = NVFP4BlockScaling()" ] }, { @@ -135,7 +183,7 @@ "id": "f9591eb5", "metadata": {}, "source": [ - "This recipe is then used to configure the FP8 training." + "This recipe is then used to configure the low precision training." ] }, { @@ -235,13 +283,13 @@ { "data": { "text/plain": [ - "tensor([[ 0.2276, 0.2627, 0.3001, ..., 0.0346, 0.2211, 0.1188],\n", - " [-0.0963, -0.3725, 0.1717, ..., 0.0901, 0.0522, -0.3472],\n", - " [ 0.4526, 0.3482, 0.5976, ..., -0.0687, -0.0382, 0.1566],\n", + "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n", + " [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n", + " [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n", " ...,\n", - " [ 0.1698, 0.6061, 0.0385, ..., -0.2875, -0.1152, -0.0260],\n", - " [ 0.0679, 0.2946, 0.2751, ..., -0.2284, 0.0517, -0.1441],\n", - " [ 0.1865, 0.2353, 0.9172, ..., 0.1085, 0.1135, 0.1438]],\n", + " [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n", + " [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n", + " [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)" ] }, @@ -263,13 +311,13 @@ { "data": { "text/plain": [ - "tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.0233, 0.2498, 0.1131],\n", - " [-0.0767, -0.3778, 0.1862, ..., 0.0858, 0.0676, -0.3369],\n", - " [ 0.4615, 0.3593, 0.5813, ..., -0.0779, -0.0349, 0.1422],\n", + "tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.1134, -0.3661, 0.1650],\n", + " [-0.0767, -0.3778, 0.1862, ..., -0.1370, -0.8448, -0.1770],\n", + " [ 0.4615, 0.3593, 0.5813, ..., 0.1696, -0.8826, -0.1826],\n", " ...,\n", - " [ 0.1914, 0.6038, 0.0382, ..., -0.2847, -0.0991, -0.0423],\n", - " [ 0.0864, 0.2895, 0.2719, ..., -0.2388, 0.0772, -0.1541],\n", - " [ 0.2019, 0.2275, 0.9027, ..., 0.1022, 0.1300, 0.1444]],\n", + " [ 0.1914, 0.6038, 0.0382, ..., 0.4049, -0.4729, 0.0118],\n", + " [ 0.0864, 0.2895, 0.2719, ..., -0.3337, -0.4922, 0.1240],\n", + " [ 0.2019, 0.2275, 0.9027, ..., 0.0706, -0.5481, 0.1356]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)" ] }, @@ -300,13 +348,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.0346, 0.2211, 0.1188],\n", - " [-0.0963, -0.3724, 0.1717, ..., 0.0901, 0.0522, -0.3470],\n", - " [ 0.4526, 0.3479, 0.5976, ..., -0.0686, -0.0382, 0.1566],\n", + "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n", + " [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n", + " [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n", " ...,\n", - " [ 0.1698, 0.6062, 0.0385, ..., -0.2876, -0.1152, -0.0260],\n", - " [ 0.0679, 0.2947, 0.2750, ..., -0.2284, 0.0516, -0.1441],\n", - " [ 0.1865, 0.2353, 0.9170, ..., 0.1085, 0.1135, 0.1438]],\n", + " [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n", + " [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n", + " [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)\n" ] } @@ -339,19 +387,14 @@ { "data": { "text/plain": [ - "tensor([[ 4.9591e-05, -1.9073e-04, 9.5367e-05, ..., -3.8147e-06,\n", - " 4.1962e-05, 2.2888e-05],\n", - " [ 2.2888e-05, -3.4332e-05, 2.2888e-05, ..., 2.6703e-05,\n", - " 5.3406e-05, -1.4114e-04],\n", - " [-3.8147e-05, 2.6703e-04, -3.8147e-06, ..., -5.7220e-05,\n", - " 4.1962e-05, -1.9073e-05],\n", + "tensor([[0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", - " [ 1.1444e-05, -7.2479e-05, -3.8147e-06, ..., 5.3406e-05,\n", - " -1.5259e-05, 2.2888e-05],\n", - " [ 4.9591e-05, -9.5367e-05, 6.8665e-05, ..., -1.5259e-05,\n", - " 7.6294e-05, 4.5776e-05],\n", - " [-1.5259e-05, -7.6294e-06, 1.8692e-04, ..., -3.0518e-05,\n", - " -4.5776e-05, 7.6294e-06]], device='cuda:0', grad_fn=)" + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0',\n", + " grad_fn=)" ] }, "execution_count": 7, @@ -370,6 +413,53 @@ "source": [ "The differences in result coming from FP8 execution do not matter during the training process, but it is good to understand them, e.g. during debugging the model." ] + }, + { + "cell_type": "markdown", + "id": "d45e8b6c-803b-4a4f-8835-c19b0a94bc6a", + "metadata": {}, + "source": [ + "### Using multiple recipes in the same training run\n", + "\n", + "Sometimes it is desirable to use multiple recipes in the same training run. An example of this is the NVFP4 training, where a few layers at the end of the training should be run in higher precision. This can be achieved by using multiple autocasts, either completely separately or in a nested way (this could be useful when e.g. we want to have a configurable overarching recipe but still hardcode a different recipe for some pieces of the network)." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c663f694-41d6-47c0-a397-5fc56e692542", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.0547, 0.0039, -0.0664, ..., -0.2061, 0.2344, -0.3223],\n", + " [ 0.0131, -0.1436, 0.0168, ..., -0.4258, 0.1562, -0.0371],\n", + " [ 0.1074, -0.2773, 0.0576, ..., -0.2070, 0.0640, -0.1611],\n", + " ...,\n", + " [ 0.0825, -0.0630, 0.0571, ..., -0.3711, 0.1562, -0.4062],\n", + " [-0.1729, -0.1138, -0.0620, ..., -0.4238, 0.0703, -0.2070],\n", + " [-0.0908, -0.2148, 0.2676, ..., -0.4551, 0.1836, -0.4551]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=<_LinearBackward>)\n" + ] + } + ], + "source": [ + "my_linear1 = te.Linear(768, 768).bfloat16() # The first linear - we want to run it in FP4\n", + "my_linear2 = te.Linear(768, 768).bfloat16() # The second linear - we want to run it in MXFP8\n", + "\n", + "inp = inp.bfloat16()\n", + "\n", + "with te.fp8_autocast(fp8_recipe=nvfp4_recipe):\n", + " y = my_linear1(inp)\n", + " with te.fp8_autocast(fp8_recipe=mxfp8_recipe):\n", + " out = my_linear2(y)\n", + "\n", + "print(out)\n", + "\n", + "out.mean().backward()" + ] } ], "metadata": { From 0db0f4d2d7ca7ae6e761294aedc74b6e30a8aaf4 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 6 Oct 2025 12:05:31 -0400 Subject: [PATCH 034/141] [JAX] Fix for GEMM + fuse bias + AllReduce (#2230) * not fuse bias for output all reduction case + unit tests Signed-off-by: Phuong Nguyen * norm to reduce dgamma along tpsp as well Signed-off-by: Phuong Nguyen * clean up tests Signed-off-by: Phuong Nguyen * fix test_distributed_layernorm byte counts Signed-off-by: Phuong Nguyen * increase tols for jax_gemm Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- tests/jax/distributed_test_base.py | 17 +- tests/jax/test_distributed_dense.py | 253 ++++++++++++++++++ tests/jax/test_distributed_layernorm.py | 19 +- tests/jax/test_distributed_layernorm_mlp.py | 51 +++- transformer_engine/jax/cpp_extensions/gemm.py | 84 +++--- .../jax/cpp_extensions/normalization.py | 6 +- .../jax/csrc/extensions/gemm.cpp | 16 +- transformer_engine/jax/sharding.py | 15 ++ 8 files changed, 382 insertions(+), 79 deletions(-) create mode 100644 tests/jax/test_distributed_dense.py diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 7c08539c3..4693086b8 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -17,14 +17,6 @@ def generate_configs(): configs = [] - if is_devices_enough(2): - configs.append( - pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") - ) - configs.append( - pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2") - ) - if is_devices_enough(4): configs.append( pytest.param( @@ -32,10 +24,17 @@ def generate_configs(): (2, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp"), - id=f"n4_dp2_tp2", + id="n4_dp2_tp2", ) ) + if is_devices_enough(2): + configs.append( + pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") + ) + configs.append( + pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"), + ) return configs diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py new file mode 100644 index 000000000..9541ccfcb --- /dev/null +++ b/tests/jax/test_distributed_dense.py @@ -0,0 +1,253 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from functools import partial + +from distributed_test_base import generate_configs +from utils import assert_allclose, pytest_parametrize_wrapper + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.dense import dense + + +DTYPES = [jnp.bfloat16] + +GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in] + +WEIGHT_SHAPES = [[256, 256]] # [hidden_in, hidden_out] + + +def _generate_inputs(input_shape, weight_shape, dtype): + """Generate test inputs for GEMM operations""" + _, _, hidden_in = input_shape + hidden_in_w, hidden_out = weight_shape + assert hidden_in == hidden_in_w, f"Dimension mismatch: {hidden_in} != {hidden_in_w}" + + bias_shape = (hidden_out,) + + # Generate random inputs + x = random.normal(random.PRNGKey(1124), input_shape, dtype=dtype) + weight = random.normal(random.PRNGKey(2248), weight_shape, dtype=dtype) / jnp.sqrt(hidden_in_w) + bias = random.normal(random.PRNGKey(3372), bias_shape, dtype=dtype) / jnp.sqrt(hidden_out) + + return x, weight, bias + + +def _get_sharding_for_gemm(mesh, mesh_resource, partition_layout="rowwise"): + """Get sharding patterns for GEMM inputs and outputs""" + + dp_axis = mesh_resource.dp_resource + tp_axis = mesh_resource.tpsp_resource + + if partition_layout == "colwise": + x_spec = PartitionSpec(dp_axis, None, None) + weight_spec = PartitionSpec(None, tp_axis) + bias_spec = PartitionSpec(tp_axis) + output_spec = PartitionSpec(dp_axis, None, tp_axis) + elif partition_layout == "rowwise": + x_spec = PartitionSpec(dp_axis, None, tp_axis) + weight_spec = PartitionSpec(tp_axis, None) + bias_spec = PartitionSpec(None) + output_spec = PartitionSpec(dp_axis, None, None) + else: + raise ValueError(f"Invalid partition: {partition_layout}") + + x_sharding = NamedSharding(mesh, x_spec) + weight_sharding = NamedSharding(mesh, weight_spec) + bias_sharding = NamedSharding(mesh, bias_spec) + output_sharding = NamedSharding(mesh, output_spec) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +@partial(jax.jit, static_argnames=("contracting_dims", "output_sharding")) +def _jitted_gemm(x, weight, bias, contracting_dims, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + fuse_bias=True, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +# TODO(Phuong): +# 1. Add supported recipes after FP4 is added +# 2. Add communication type/byte checks +class TestDistributedDense: + """Test distributed GEMM without collective operations vs JAX dot""" + + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_configs(), + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES) + @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES) + @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"]) + def test_distributed_gemm( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + dtype, + input_shape, + weight_shape, + partition, + ): + """Test TE GEMM against JAX dot with bf16 dtype""" + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + # Generate inputs + x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype) + + # Get sharding patterns + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm( + mesh, mesh_resource, partition_layout=partition + ) + + # Shard inputs + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension + + with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + # TE GEMM result + te_result = _jitted_gemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=contracting_dims, + output_sharding=output_sharding, + ) + + # JAX dot reference result + jax_result = ( + jax.lax.dot_general( + x_sharded, weight_sharded, dimension_numbers=(contracting_dims, ((), ())) + ) + + bias_sharded + ) + + assert te_result.sharding == jax_result.sharding + # Ensure computation is complete + jax.block_until_ready(te_result) + jax.block_until_ready(jax_result) + + # Gather results for comparison + gathered_te = jax.lax.with_sharding_constraint( + te_result, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_jax = jax.lax.with_sharding_constraint( + jax_result, NamedSharding(mesh, PartitionSpec(None)) + ) + + # Compare results + assert_allclose(gathered_te, gathered_jax, dtype=dtype) + + def _te_sum_dense(self, x, weight, bias, contracting_dims): + """TE GEMM function for gradient testing""" + return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims)) + + def _jax_sum_dense(self, x, weight, bias, contracting_dims): + """JAX dot function for gradient testing""" + result = ( + jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias + ) + return jnp.sum(result) + + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_configs(), + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES) + @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES) + @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"]) + def test_te_distributed_dense_grad( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + dtype, + input_shape, + weight_shape, + partition, + ): + """Test TE GEMM gradients against JAX dot gradients""" + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + # Generate inputs + x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype) + + # Get sharding patterns + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm( + mesh, mesh_resource, partition_layout=partition + ) + + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + contracting_dims = ((2,), (0,)) + + with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + # Test gradients w.r.t. all inputs + te_grad_func = jax.jit( + jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), + static_argnames=("contracting_dims",), + ) + jax_grad_func = jax.jit( + jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)), + static_argnames=("contracting_dims",), + ) + + te_val, te_grads = te_grad_func( + x_sharded, weight_sharded, bias_sharded, contracting_dims + ) + jax_val, jax_grads = jax_grad_func( + x_sharded, weight_sharded, bias_sharded, contracting_dims + ) + + # Compare forward pass + assert_allclose(te_val, jax_val, dtype=dtype) + + # Compare gradients + for i, (te_grad, jax_grad) in enumerate(zip(te_grads, jax_grads)): + te_grad_spec = tuple(i for i in te_grad.sharding.spec if i is not None) + jax_grad_spec = tuple(i for i in jax_grad.sharding.spec if i is not None) + assert te_grad_spec == jax_grad_spec, f"Gradient sharding mismatch at te_grads[{i}]" + gathered_te_grad = jax.lax.with_sharding_constraint( + te_grad, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_jax_grad = jax.lax.with_sharding_constraint( + jax_grad, NamedSharding(mesh, PartitionSpec(None)) + ) + assert_allclose( + gathered_te_grad, + gathered_jax_grad, + dtype=dtype, + err_msg=f"Gradient mismatch for argument {i}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index f3296277c..5fa08fa08 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -66,18 +66,19 @@ def generate_collectives_count_ref( self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) + # TODO(Phuong) is_dp_enabled = dp mesh axis size > 1 is_dp_enabled = mesh_resource.dp_resource is not None + is_tpsp_enabled = mesh_resource.tpsp_resource is not None assert ln_type in ["layernorm", "rmsnorm"] - all_reduce_loss_bytes = 4 # 1 * FP32 - # for loss, dgamma and dbeta - # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp - weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1 - allreduce_total_bytes = ( - all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize - ) - other_bytes = 0 + # loss, 1 FP32 + allreduce_total_bytes = 4 if is_dp_enabled else 0 + # dgamma and dbeta + weight_count = 2 if ln_type == "layernorm" else 1 + allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize return generate_collectives_count( - allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes + allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled), + allgather=0, + other=0, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index a44921c64..d38f43d00 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -48,7 +48,7 @@ SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) @@ -59,19 +59,47 @@ LN_BIAS_AXES = (W_NO_SHARD_AXES,) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) -INTERMEDIATE = 64 +INTERMEDIATE = 256 # Only test with FSDP and TPSP as DP is not used def generate_fsdp_and_tpsp_configs(): configs = [] + if is_devices_enough(4): + configs.append( + pytest.param( + [ + 4, + (2, 2), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp2_tpsp2", + ) + ) + if is_devices_enough(2): configs.append( - [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] + pytest.param( + [ + 2, + (1, 2), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp1_tpsp2", + ) ) - if is_devices_enough(4): configs.append( - [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] + pytest.param( + [ + 2, + (2, 1), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp2_tpsp1", + ), ) return configs @@ -229,10 +257,7 @@ def _test_layernorm_mlp_grad( fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 - if fwd_test_type == jnp.float16 and use_bias: - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) - else: - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: @@ -381,6 +406,7 @@ def _test_layernorm_mlp( assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) + # TODO(Phuong): check if these tols updates are still needed atol = None rtol = None l40_tolerance_update = ( @@ -404,9 +430,10 @@ def _test_layernorm_mlp( # within tolerance to the float32 ground truth. jax_triton_gemm_precision_tolerance_update = ( with_jax_gemm - and isinstance(fp8_recipe, recipe.Float8CurrentScaling) - and dtype == jnp.bfloat16 - and activation_type == ("gelu", "linear") + and fp8_recipe is not None + and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) + and dtype in (jnp.bfloat16, jnp.float16) + and activation_type == ("gelu", "linear"), ) if jax_triton_gemm_precision_tolerance_update: atol = 0.08 diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e5fcdac3c..865efe89d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -451,23 +451,19 @@ def _dims_are_consecutive(dims): output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) # Validate bias - bias_shape = (0,) - bias_dtype = out_dtype if fuse_bias: - expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) - if not grad: - assert bias.size == expected_bias_size, ( - "cuBLAS GEMM bias tensor has incorrect shape, " - f"expected ({expected_bias_size}, ) but found {bias.shape}." - ) - assert bias.dtype == out_dtype, ( - "cuBLAS GEMM bias tensor has incorrect data type, " - f"expected {bias_dtype} but found {bias.dtype}." - ) - bias_shape = bias.shape - else: - bias_shape = rhs_non_contracting_shape - bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + assert bias.shape == tuple(rhs_non_contracting_shape), ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {out_dtype} but found {bias.dtype}." + ) + # WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we + # change the fuse_bias value in the sharded_impl + dbias_shape = bias.shape if grad else (0,) + bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype) # Validate pre-GeLU pre_gelu_shape = (0,) @@ -548,7 +544,7 @@ def lowering( } operand_output_aliases = {} - if fuse_bias and not grad: + if grad: operand_output_aliases.update({4: 1}) # bias <-> bias_grad if fuse_gelu and grad: operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out @@ -927,7 +923,6 @@ def infer_sharding_from_operands( del ( out_dtype, scaling_mode, - grad, use_split_accumulator, result_infos, is_outer, @@ -941,8 +936,8 @@ def infer_sharding_from_operands( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: + # Discard dbias gradient spec if there is no bias and grad fusion + if not (fuse_bias and grad): dbias_specs = (None,) dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) @@ -1008,8 +1003,8 @@ def partition( # Assemble output shardings out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: + # Discard bias gradient spec if there is no bias and grad fusion + if not (fuse_bias and grad): dbias_specs = (None,) out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) @@ -1019,6 +1014,8 @@ def partition( out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + # We should not fuse bias in the output reduction case + sharded_fuse_bias = fuse_bias and reduce_spec is None outputs = GemmPrimitive.impl( lhs, lhs_scale_inv, @@ -1029,7 +1026,7 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, - fuse_bias=fuse_bias, + fuse_bias=sharded_fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, @@ -1039,13 +1036,17 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): collective_op=collective_op, ) - if reduce_spec is not None and not collective_op.is_reduce_scatter: - if is_all_reduce_in_float32(): # For unittest only - outputs[0] = jax.lax.psum(outputs[0].astype(jnp.float32), reduce_spec).astype( - out_dtype - ) - else: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + if reduce_spec is not None: + if not collective_op.is_reduce_scatter: + if is_all_reduce_in_float32(): # For unittest only + outputs[0] = jax.lax.psum( + outputs[0].astype(jnp.float32), reduce_spec + ).astype(out_dtype) + else: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + + if fuse_bias: # TODO(Phuong): rename fuse_bias to has_bias + outputs[0] += bias return outputs @@ -1068,7 +1069,7 @@ def shardy_sharding_rule( operand_types, result_types, ): - del out_dtype, grad, use_split_accumulator + del out_dtype, use_split_accumulator del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer if not collective_op.is_none: @@ -1079,12 +1080,6 @@ def shardy_sharding_rule( prefix = "Gemm_" - warnings.warn( - "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," - " please turn off Shardy by exporting the environment variable" - " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems." - ) - def _generate_operand_rules(name, ndim, cdims): specs = [] ldims = tuple(i for i in range(ndim) if i not in cdims) @@ -1118,7 +1113,8 @@ def _generate_operand_rules(name, ndim, cdims): rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) bias_spec = rhs_non_cspec if fuse_bias else ("…4",) - gelu_spec = out_spec if fuse_gelu else ("…5",) + dbias_spec = bias_spec if grad else ("…5") + gelu_spec = out_spec if fuse_gelu else ("…6",) return SdyShardingRule( operand_mappings=( @@ -1131,7 +1127,7 @@ def _generate_operand_rules(name, ndim, cdims): ), result_mappings=( out_spec, - bias_spec, + dbias_spec, gelu_spec, ), ) @@ -1161,6 +1157,13 @@ def _te_gemm( collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: + if grad or fuse_gelu: + warnings.warn( + "GEMM + fused grad or fused gelu is not well tested and will be deprecated in the" + " future", + DeprecationWarning, + ) + # Prepare non-quantized GEMM operands lhs_data = lhs rhs_data = rhs @@ -1228,7 +1231,7 @@ def _te_gemm( grad=grad, use_split_accumulator=use_split_accumulator, transpose_batch_sequence=transpose_batch_sequence, - sequence_dim=-1, + sequence_dim=-1, # Dummy value and will be set in the primitive is_outer=True, collective_op=collective_op, ) @@ -1618,6 +1621,7 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + # TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None fuse_bias = kwargs.get("fuse_bias", False) fuse_gelu = kwargs.get("fuse_gelu", False) if not GemmPrimitive.enabled(): diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3348c725b..ef6373688 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -28,7 +28,7 @@ get_cudnn_version, ) from .quantization import _quantize_dbias_impl, AmaxScope -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, @@ -801,9 +801,9 @@ def sharded_impl(dz, x, mu, rsigma, gamma): norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, ) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) + global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh) if norm_type == NVTE_Norm_Type.LayerNorm: - global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) + global_dbeta = all_reduce_sum_along_dp_fsdp_tpsp(local_dbeta, mesh) else: global_dbeta = local_dbeta return local_dx, global_dgamma, global_dbeta diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 1467fa887..f2007efcf 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -158,18 +158,18 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; - std::vector bias_shape = {0}; + size_t bias_size = 0; DType bias_dtype = out_dtype; if (fuse_bias) { - if (!grad) { + if (grad) { NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); } - bias_ptr = bias_grad->untyped_data(); - bias_shape.at(0) = bias_grad->dimensions().front(); - bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + bias_ptr = bias.untyped_data(); + bias_size = product(bias.dimensions()); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); } - auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + auto bias_ = TensorWrapper(bias_ptr, std::vector{bias_size}, bias_dtype); // Pre-GeLU output from forward pass or input to backward pass void *pre_gelu_ptr = nullptr; @@ -202,6 +202,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", to_string_like(out_shape), " but got ", output->element_count(), " elements ", to_string_like(output->dimensions())); + NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, + ", out_shape[1]=", out_shape[1]); nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), rhs_transposed, lhs_transposed, grad, workspace_.data(), false, @@ -220,6 +222,8 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i buffer_shape[1] = out_shape[1]; out_shape[0] = out_shape[0] / comm_handler.tp_size; } + NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, + ", out_shape[1]=", out_shape[1]); auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( buffer_shape, buffer_dtype, collective_op); if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index d3a7952d3..8eeaca4cc 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -365,6 +365,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) +def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh): + """Perform all-reduce sum operation along data parallelism and sequence parallelism axes. + + Args: + x: Input tensor to reduce + mesh: JAX mesh for distributed computation + + Returns: + Reduced tensor + """ + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tpsp_resource, mesh) + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh) + return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) + + def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): """Perform all-reduce max operation along all axes except pipeline parallelism. From 56e2fede5ace495c0aa817e802ea3504c5974b11 Mon Sep 17 00:00:00 2001 From: Kiv Chen <34561254+KivenChen@users.noreply.github.com> Date: Mon, 6 Oct 2025 09:38:57 -0700 Subject: [PATCH 035/141] [Build] fix: TE installation failed to find uv-installed cuDNN libraries (#2207) [Build] fix: python platlib path Signed-off-by: Kiv Chen Co-authored-by: Kirthi Shankar Sivamani --- build_tools/build_ext.py | 1 + 1 file changed, 1 insertion(+) diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 3aa45f024..349858ac4 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -57,6 +57,7 @@ def _build_cmake(self, build_dir: Path, install_dir: Path) -> None: build_dir, f"-DPython_EXECUTABLE={sys.executable}", f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}", + f"-DPython_SITEARCH={sysconfig.get_path('platlib')}", f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", ] From 9f3e79bff824d3a9f10267dc414308011c87b093 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 6 Oct 2025 15:01:04 -0700 Subject: [PATCH 036/141] =?UTF-8?q?[PyTorch]=20Fix=20tests=20for=20?= =?UTF-8?q?=F0=9F=A4=97=20integration=20(#2239)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update test requirements for HF Signed-off-by: Kirthi Shankar Sivamani * Update build_tools/pytorch.py Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- build_tools/pytorch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index a974e370d..3d44d8740 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -19,7 +19,7 @@ def install_requirements() -> List[str]: def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions.""" - return ["numpy", "torchvision", "transformers"] + return ["numpy", "torchvision", "transformers", "torchao==0.13"] def setup_pytorch_extension( From 127b6d3ab3088c403f3b38b8405b70fc33ee3f34 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 7 Oct 2025 10:59:24 -0400 Subject: [PATCH 037/141] [JAX] Activation/Normalization to output amax for later quantization in CurrentScaling (#2238) * reuse amax for current scaling Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 371 ++++++++++++++---- .../jax/cpp_extensions/normalization.py | 198 ++++++++-- .../jax/cpp_extensions/quantization.py | 61 +-- .../jax/csrc/extensions/activation.cpp | 99 +++-- .../jax/csrc/extensions/normalization.cpp | 60 +-- .../jax/csrc/extensions/quantization.cpp | 6 +- transformer_engine/jax/dense.py | 30 +- transformer_engine/jax/flax/module.py | 16 + transformer_engine/jax/flax/transformer.py | 5 + transformer_engine/jax/layernorm_dense.py | 19 + transformer_engine/jax/layernorm_mlp.py | 41 +- 11 files changed, 677 insertions(+), 229 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 925c1d01a..be1f9f956 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -148,7 +148,6 @@ class ActLuPrimitive(BasePrimitive): name = "te_act_lu_ffi" multiple_results = True impl_static_args = ( - 2, 3, 4, 5, @@ -156,7 +155,11 @@ class ActLuPrimitive(BasePrimitive): 7, 8, 9, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer, act_params + 10, + 11, + 12, + 13, + ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer inner_primitive = None outer_primitive = None @@ -164,6 +167,7 @@ class ActLuPrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, *, out_dtype, act_enum, @@ -171,16 +175,23 @@ def abstract( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_act_lu_p abstract """ - del act_enum, act_params + del act_enum, act_params, amax_scope, transpose_batch_sequence + assert ( + not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value + ), f"scaling_mode = {scaling_mode}" dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 assert x_aval.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x_aval.shape} and act_len {act_len}" @@ -215,6 +226,7 @@ def lowering( ctx, x, scale, + amax, *, out_dtype, act_enum, @@ -222,24 +234,34 @@ def lowering( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_gated_act_lu_p lowering rules """ - del out_dtype, scale_dtype, act_len, is_outer - x_aval, scale_aval = ctx.avals_in + del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence + x_aval, scale_aval, amax_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( + assert amax_aval.dtype == jnp.float32 + + out = ffi.ffi_lowering( + ActLuPrimitive.name, + operand_output_aliases={2: 4}, # donate amax buffer to updated_amax + )( ctx, x, scale, + amax, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x, act_params=act_params.to_ffi_lowering_dict(), + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return out @@ -247,14 +269,18 @@ def lowering( def impl( x, scale, + amax, out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ to describe implementation @@ -266,14 +292,18 @@ def impl( ActLuPrimitive.inner_primitive.bind( x, scale, + amax, out_dtype=out_dtype, act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, - is_outer=False, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=False, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -301,17 +331,19 @@ def batcher( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ to describe batch rules for vmap """ - del act_len, is_outer check_valid_batch_dims(batch_dims) assert ActLuPrimitive.outer_primitive is not None - x, scale = batched_args - x_bdim, scale_bdim = batch_dims + x, scale, amax = batched_args + x_bdim, scale_bdim, _ = batch_dims amax_bdim = scale_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim @@ -319,12 +351,18 @@ def batcher( ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=out_dtype, act_enum=act_enum, + act_len=act_len, scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -337,8 +375,11 @@ def infer_sharding_from_operands( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, @@ -349,8 +390,11 @@ def infer_sharding_from_operands( act_enum, scale_dtype, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ) # Unused. x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -402,13 +446,16 @@ def partition( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, ): - del result_infos, is_outer # Unused. + del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) @@ -452,26 +499,40 @@ def partition( amax_sharding, ) - def sharded_impl(x, scale): - local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = ( - ActLuPrimitive.impl( - x, - scale, - out_dtype=out_dtype, - act_enum=act_enum, - act_len=act_len, - scaling_mode=scaling_mode, - is_2x=is_2x, - scale_dtype=scale_dtype, - is_outer=True, - act_params=act_params, - ) + def sharded_impl(x, scale, amax): + ( + local_x, + local_colwise_x, + local_scale_inv, + local_colwise_scale_inv, + local_updated_amax, + ) = ActLuPrimitive.impl( + x, + scale, + amax, + out_dtype=out_dtype, + act_enum=act_enum, + act_len=act_len, + scaling_mode=scaling_mode, + is_2x=is_2x, + scale_dtype=scale_dtype, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, out_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return ( local_x, @@ -491,13 +552,28 @@ def shardy_sharding_rule( scaling_mode, is_2x, scale_dtype, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, value_types, result_types, ): - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types, act_params + del ( + out_dtype, + act_enum, + act_len, + scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, + mesh, + result_types, + ) prefix = "ActLu_" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] @@ -526,6 +602,7 @@ def shardy_sharding_rule( ( x_axes, ("…1",), + amax, ), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), **scale_rules.factor_sizes, @@ -543,8 +620,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer, act_params - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10, 11) + # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -553,6 +630,7 @@ def abstract( dz_aval, x_aval, scale_aval, + amax_aval, *, out_dtype, scaling_mode, @@ -561,13 +639,16 @@ def abstract( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_dact_dbias_quantize_p abstract """ - del act_enum, act_params + del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype @@ -576,6 +657,7 @@ def abstract( f" {x_aval.shape} and act_len {act_len}" ) assert scale_aval.dtype == jnp.float32 + assert amax_aval.dtype == jnp.float32 assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( "Current tensor scaling is not supported for fused dact and quantization. Please do" @@ -655,6 +737,7 @@ def lowering( dz, x, scale, + amax, *, out_dtype, scaling_mode, @@ -663,27 +746,42 @@ def lowering( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_dact_dbias_quantize_p lowering rules """ - del out_dtype, scale_dtype, act_len, is_outer - dz_aval, x_aval, scale_aval = ctx.avals_in + del ( + out_dtype, + scale_dtype, + act_len, + is_outer, + amax_scope, + transpose_batch_sequence, + ) + dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_aval.dtype - assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)( + assert scale_aval.dtype == amax_aval.dtype == jnp.float32 + return ffi.ffi_lowering( + BaseDActLuDBiasQuantizePrimitive.name, + operand_output_aliases={3: 4}, # donate amax buffer to updated_amax + )( ctx, dz, x, scale, + amax, scaling_mode=scaling_mode.value, is_2x=is_2x, is_dbias=is_dbias, act_enum=int(act_enum), act_params=act_params.to_ffi_lowering_dict(), + output_amax_when_no_scaling=output_amax_when_no_scaling, ) @staticmethod @@ -691,6 +789,7 @@ def impl( dz, x, scale, + amax, out_dtype, scaling_mode, is_2x, @@ -698,8 +797,11 @@ def impl( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ te_dact_dbias_quantize_p impl @@ -711,6 +813,7 @@ def impl( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, is_2x=is_2x, @@ -718,8 +821,11 @@ def impl( is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, - is_outer=False, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=False, ) ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -747,17 +853,19 @@ def batcher( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, ): """ to describe batch rules for vmap """ - del is_outer check_valid_batch_dims(batch_dims) assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None - dz, x, scale = batched_args - _, x_bdim, scale_bdim = batch_dims + dz, x, scale, amax = batched_args + _, x_bdim, scale_bdim, _ = batch_dims out_bdims = ( x_bdim, # rowwise output @@ -772,6 +880,7 @@ def batcher( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, is_2x=is_2x, @@ -780,6 +889,10 @@ def batcher( act_enum=act_enum, act_len=act_len, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -793,14 +906,18 @@ def infer_sharding_from_operands( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum, act_params - del scale_dtype, act_len, is_outer + del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling + del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence + x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -869,8 +986,11 @@ def partition( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, arg_infos, result_infos, @@ -937,12 +1057,13 @@ def partition( dbias_sharding, ) - def sharded_impl(dz, x, scale): - (out, colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = ( + def sharded_impl(dz, x, scale, amax): + (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = ( BaseDActLuDBiasQuantizePrimitive.impl( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, is_2x=is_2x, @@ -950,8 +1071,11 @@ def sharded_impl(dz, x, scale): is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, - is_outer=True, act_params=act_params, + output_amax_when_no_scaling=output_amax_when_no_scaling, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + is_outer=True, ) ) if is_dbias: @@ -960,9 +1084,15 @@ def sharded_impl(dz, x, scale): global_dbias = local_dbias if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, x_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias @@ -977,14 +1107,30 @@ def shardy_sharding_rule( is_dbias, act_enum, act_len, - is_outer, act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types, act_params + del ( + out_dtype, + scale_dtype, + act_enum, + act_len, + act_params, + is_outer, + output_amax_when_no_scaling, + mesh, + result_types, + amax_scope, + transpose_batch_sequence, + ) + prefix = "DActLuDBias_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 @@ -1006,7 +1152,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (dz_axes, x_axes, ("…2",)), + (dz_axes, x_axes, ("…2",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), **scale_rules.factor_sizes, ) @@ -1092,6 +1238,8 @@ def act_lu( quantizer: Optional[Quantizer] = None, act_params: Optional[ActivationParams] = None, amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -1108,6 +1256,8 @@ def act_lu( If quantizer is provided: A ScaledTensor containing the quantized activated input. """ + # TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl() + # Do the same with dact_dbias_quantize() and layernorm_fwd() act_type_id = ActivationEnum[activation_type].value act_len = len(activation_type) assert x.shape[-2] == act_len, ( @@ -1123,30 +1273,44 @@ def act_lu( return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer, act_params=act_params + f=act_lu, + x=x, + activation_type=activation_type, + quantizer=quantizer, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) + amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,) + if quantizer is None: - out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( + out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) out = out.reshape(output_shape) + # TODO(Phuong): ScaledTensorFactory to create NoScaledTensor out = NoScaleTensor( data=out, - amax=None, + amax=updated_amax if output_amax_when_no_scaling else None, ) return out @@ -1157,6 +1321,9 @@ def act_lu( activation_type=activation_type, quantizer=None, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, _ = _quantize_dbias_impl( out, @@ -1164,6 +1331,7 @@ def act_lu( quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out if isinstance(quantizer, DelayedScaleQuantizer): @@ -1178,14 +1346,18 @@ def act_lu( ) = ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=quantizer.q_dtype, act_enum=act_type_id, act_len=act_len, scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), scale_dtype=quantizer.get_scale_dtype(), - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) quantizer.update(updated_amax) @@ -1209,6 +1381,9 @@ def quantize_dact_dbias( is_dbias: bool = True, quantizer: Optional[Quantizer] = None, act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1232,7 +1407,8 @@ def quantize_dact_dbias( f" {x.shape} and act_len {act_len}" ) - scale = jnp.empty((), jnp.float32) + scale = jnp.empty((1,), jnp.float32) + amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,) act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive if not PrimitiveClass.enabled() or ( @@ -1240,10 +1416,11 @@ def quantize_dact_dbias( ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind( dz, x, scale, + amax, # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset @@ -1253,8 +1430,11 @@ def quantize_dact_dbias( is_dbias=False, act_enum=act_type_id, act_len=act_len, - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) output = output.astype(x.dtype) dbias = None @@ -1263,7 +1443,7 @@ def quantize_dact_dbias( output = NoScaleTensor( data=output, - amax=None, + amax=updated_amax if output_amax_when_no_scaling else None, ) return output, dbias @@ -1275,9 +1455,18 @@ def quantize_dact_dbias( activation_type, quantizer=None, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return _quantize_dbias_impl( - out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out.data, + quantizer, + is_dbias=True, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) is_gated = act_len == 2 @@ -1292,6 +1481,9 @@ def quantize_dact_dbias( quantizer=quantizer, flatten_axis=-2, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) if war_output is not None: return war_output @@ -1304,9 +1496,18 @@ def quantize_dact_dbias( activation_type=activation_type, quantizer=None, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, dbias = _quantize_dbias_impl( - out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out, + is_dbias=is_dbias, + quantizer=quantizer, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, dbias @@ -1320,9 +1521,17 @@ def quantize_dact_dbias( x.astype(jnp.float32), activation_type=activation_type, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) out, dbias = _quantize_dbias_impl( - dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + dgated, + quantizer, + is_dbias=True, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, dbias @@ -1337,6 +1546,7 @@ def quantize_dact_dbias( dz, x, scale, + amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, is_2x=quantizer.is_2x2x(), @@ -1344,8 +1554,11 @@ def quantize_dact_dbias( is_dbias=is_dbias, act_enum=act_type_id, act_len=act_len, - is_outer=True, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise @@ -1375,6 +1588,9 @@ def dact_lu( activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1396,5 +1612,8 @@ def dact_lu( is_dbias=False, quantizer=quantizer, act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return output diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index ef6373688..3ce8a19a7 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -92,7 +92,7 @@ class NormFwdPrimitive(BasePrimitive): name = "te_norm_forward_ffi" multiple_results = True - impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11) + impl_static_args = (5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -100,6 +100,7 @@ class NormFwdPrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, gamma_aval, beta_aval, *, @@ -110,15 +111,27 @@ def abstract( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ LayerNorm fwd inner primitive abstract """ + del amax_scope, transpose_batch_sequence + assert not output_amax_when_no_scaling or ( + scaling_mode == ScalingMode.NO_SCALING.value + and not is_norm_fwd_cudnn_enabled(scaling_mode) + ), ( + f"scaling_mode = {scaling_mode}," + f" use_cudnn_norm_fwd={is_norm_fwd_cudnn_enabled(scaling_mode)}" + ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -220,6 +233,7 @@ def lowering( ctx, x, scale, + amax, gamma, beta, *, @@ -230,16 +244,20 @@ def lowering( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ LayerNorm fwd lowering rules """ - del out_dtype, scale_dtype, is_outer - x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in + del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence + x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape @@ -251,10 +269,14 @@ def lowering( assert g_shape == b_shape sm_margin = get_forward_sm_margin() - return ffi.ffi_lowering(NormFwdPrimitive.name)( + return ffi.ffi_lowering( + NormFwdPrimitive.name, + operand_output_aliases={2: 4}, # amax <-> updated_amax + )( ctx, x, scale, + amax, gamma, beta, norm_type=norm_type.value, @@ -263,12 +285,14 @@ def lowering( sm_margin=sm_margin, scaling_mode=scaling_mode.value, is_2x=is_2x, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) @staticmethod def impl( x, scale, + amax, gamma, beta, norm_type, @@ -278,6 +302,9 @@ def impl( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ @@ -297,6 +324,7 @@ def impl( ) = NormFwdPrimitive.inner_primitive.bind( x, scale, + amax, gamma, beta, norm_type=norm_type, @@ -306,6 +334,9 @@ def impl( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=False, ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -341,16 +372,18 @@ def batcher( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ to describe batch rules for vmap """ - del is_outer check_valid_batch_dims(batch_dims) assert NormFwdPrimitive.outer_primitive is not None - x, scale, gamma, beta = batched_args - x_bdim, scale_bdim, _, _ = batch_dims + x, scale, amax, gamma, beta = batched_args + x_bdim, scale_bdim, _, _, _ = batch_dims out_bdims = ( x_bdim, # rowwise output @@ -363,8 +396,9 @@ def batcher( ) return ( NormFwdPrimitive.outer_primitive.bind( - scale, x, + scale, + amax, gamma, beta, norm_type=norm_type, @@ -374,6 +408,10 @@ def batcher( scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -387,15 +425,19 @@ def infer_sharding_from_operands( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, result_infos, ): del zero_centered_gamma, epsilon, out_dtype, result_infos - del scale_dtype, is_outer + del scale_dtype, is_outer, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( @@ -415,9 +457,9 @@ def infer_sharding_from_operands( mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_inv_spec = amax_spec = (None,) + scale_inv_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec + scale_inv_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec @@ -445,6 +487,9 @@ def partition( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, @@ -453,8 +498,9 @@ def partition( del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) - g_spec = get_padded_spec(arg_infos[2]) - b_spec = get_padded_spec(arg_infos[3]) + amax_spec = get_padded_spec(arg_infos[2]) + g_spec = get_padded_spec(arg_infos[3]) + b_spec = get_padded_spec(arg_infos[4]) out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: @@ -485,9 +531,9 @@ def partition( mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_inv_spec = amax_spec = (None,) + scale_inv_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec + scale_inv_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec @@ -499,10 +545,10 @@ def partition( arg_shardings = list(arg_i.sharding for arg_i in arg_infos) # Enforce no sharding of hidden dim for x, gamma and beta arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x") - arg_shardings[2] = NamedSharding( + arg_shardings[3] = NamedSharding( mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma" ) - arg_shardings[3] = NamedSharding( + arg_shardings[4] = NamedSharding( mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta" ) arg_shardings = tuple(arg_shardings) @@ -516,19 +562,20 @@ def partition( rsigma_sharding, ) - def sharded_impl(x, scale, gamma, beta): + def sharded_impl(x, scale, amax, gamma, beta): # expect tp and dp giving same shape, or tp being same shape as global ( local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, - local_amax, + local_updated_amax, local_mu, local_rsigma, ) = NormFwdPrimitive.impl( x, scale, + amax, gamma, beta, norm_type=norm_type, @@ -538,12 +585,21 @@ def sharded_impl(x, scale, gamma, beta): scaling_mode=scaling_mode, is_2x=is_2x, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, x_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return ( local_x, @@ -566,6 +622,9 @@ def shardy_sharding_rule( scaling_mode, is_2x, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, value_types, @@ -576,6 +635,9 @@ def shardy_sharding_rule( epsilon, out_dtype, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, result_types, @@ -594,7 +656,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (x_axes, ("…1",), ("…2",), ("…3",)), + (x_axes, ("…1",), amax, ("…2",), ("…3",)), ( out, colwise_out, @@ -882,6 +944,8 @@ def layernorm_fwd( epsilon: float, quantizer: Optional[Quantizer], amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray, jnp.ndarray]: """Layer normalization forward pass with optional quantization. @@ -896,6 +960,7 @@ def layernorm_fwd( epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -918,10 +983,12 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) + amax = jnp.zeros((1,), dtype=jnp.float32) if quantizer is None: - output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( + output, _, _, _, updated_amax, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.LayerNorm, @@ -931,18 +998,37 @@ def layernorm_fwd( scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, + amax_scope=amax_scope, + transpose_batch_sequence=False, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) - return NoScaleTensor(data=output, amax=None), mu, rsigma + # cuDNN does not support amax output for non quantized output + updated_amax = ( + updated_amax + if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING) + else None + ) + return NoScaleTensor(data=output, amax=updated_amax), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, mu, rsigma = layernorm_fwd( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=False, + ) + out, _ = _quantize_dbias_impl( + out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence ) - out, _ = _quantize_dbias_impl(out, quantizer) return out, mu, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -954,6 +1040,9 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, _ = _quantize_dbias_impl( out, @@ -961,6 +1050,7 @@ def layernorm_fwd( quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, mu, rsigma @@ -979,6 +1069,7 @@ def layernorm_fwd( ) = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.LayerNorm, @@ -988,6 +1079,9 @@ def layernorm_fwd( scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) quantizer.update(updated_amax) @@ -1091,7 +1185,9 @@ def rmsnorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: Optional[Quantizer], - amax_scope: AmaxScope = AmaxScope.LOCAL, + amax_scope: AmaxScope = AmaxScope.TPSP, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: """Root mean square normalization forward pass with optional quantization. @@ -1104,6 +1200,7 @@ def rmsnorm_fwd( epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -1127,12 +1224,14 @@ def rmsnorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) + amax = jnp.zeros((1,), dtype=jnp.float32) beta = jnp.ones((1,), dtype=jnp.float32) if quantizer is None: - output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind( + output, _, _, _, updated_amax, _, rsigma = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.RMSNorm, @@ -1142,16 +1241,39 @@ def rmsnorm_fwd( scaling_mode=ScalingMode.NO_SCALING.value, is_2x=False, scale_dtype=jnp.float32, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) - return NoScaleTensor(data=output, amax=None), rsigma + # cuDNN does not support amax output for non quantized output + updated_amax = ( + updated_amax + if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING) + else None + ) + return NoScaleTensor(data=output, amax=updated_amax), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): - out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out.data, quantizer) + out, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=False, + ) + out, _ = _quantize_dbias_impl( + out.data, + quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) return out, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1162,13 +1284,17 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, _ = _quantize_dbias_impl( - out.data, + out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, rsigma @@ -1187,6 +1313,7 @@ def rmsnorm_fwd( ) = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.RMSNorm, @@ -1196,6 +1323,9 @@ def rmsnorm_fwd( scaling_mode=quantizer.scaling_mode.value, is_2x=is_2x2x, scale_dtype=quantizer.get_scale_dtype(), + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) quantizer.update(updated_amax) @@ -1294,6 +1424,7 @@ def normalization_fwd( norm_type: str, quantizer: Optional[Quantizer], amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, ): """Common wrapper for normalization forward pass. @@ -1311,6 +1442,7 @@ def normalization_fwd( - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -1336,6 +1468,7 @@ def normalization_fwd( epsilon, quantizer, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) elif norm_type == "rmsnorm": assert ( @@ -1348,6 +1481,7 @@ def normalization_fwd( epsilon, quantizer, amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) mu = None else: diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 9f9e8fec0..38fd50a00 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -543,6 +543,18 @@ class AmaxScope(Enum): TPSP = 2 FSDP = 3 + def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): + """Reduce the amax based on its scope""" + gmesh = global_mesh_resource() + sequence_dim = 0 if transpose_batch_sequence else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: + return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + # Run AR across FSDP + if self is AmaxScope.FSDP: + return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + return amax + class AmaxCalculationPrimitive(BasePrimitive): """ @@ -554,7 +566,7 @@ class AmaxCalculationPrimitive(BasePrimitive): impl_static_args = ( 1, 2, - ) # amax_scope, batch_sequence_transpose + ) # amax_scope, transpose_batch_sequence inner_primitive = None outer_primitive = None @@ -563,12 +575,12 @@ def abstract( x_aval, *, amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, ): """ amax calcuation abstract """ - del amax_scope, batch_sequence_transpose + del amax_scope, transpose_batch_sequence dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] @@ -580,19 +592,19 @@ def abstract( def impl( x, amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, ): """ amax calcuation implementation """ - del amax_scope, batch_sequence_transpose + del amax_scope, transpose_batch_sequence amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) return amax @staticmethod def infer_sharding_from_operands( amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, mesh, arg_infos, result_infos, @@ -600,7 +612,7 @@ def infer_sharding_from_operands( """ amax calcuation infer_sharding_from_operands """ - del (amax_scope, batch_sequence_transpose, arg_infos, result_infos) # Unused. + del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused. amax_sharding = NamedSharding( mesh, PartitionSpec(None), @@ -611,7 +623,7 @@ def infer_sharding_from_operands( @staticmethod def partition( amax_scope, - batch_sequence_transpose, + transpose_batch_sequence, mesh, arg_infos, result_infos, @@ -631,16 +643,11 @@ def sharded_impl(x): amax = AmaxCalculationPrimitive.impl( x, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, + ) + amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + amax, x_spec, transpose_batch_sequence, mesh ) - gmesh = global_mesh_resource() - sequence_dim = 0 if batch_sequence_transpose else 1 - # Run AR across TPSP only when tensor-sequence is detected in the input spec - if amax_scope is AmaxScope.TPSP and x_spec[sequence_dim] == gmesh.tpsp_resource: - amax = lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) - # Run AR across FSDP - if amax_scope is AmaxScope.FSDP: - amax = lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) return amax @@ -648,11 +655,11 @@ def sharded_impl(x): return mesh, sharded_impl, amax_sharding, arg_shardings @staticmethod - def shardy_sharding_rule(amax_scope, batch_sequence_transpose, mesh, value_types, result_types): + def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types): """ amax calcuation shardy_sharding_rule """ - del amax_scope, batch_sequence_transpose, mesh, result_types + del amax_scope, transpose_batch_sequence, mesh, result_types prefix = "AmaxCal" input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) output_spec = (f"{prefix}_amax",) @@ -709,7 +716,7 @@ def _quantize_dbias_impl( dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -755,12 +762,12 @@ def _quantize_dbias_impl( dq_dtype=dq_dtype, flatten_axis=flatten_axis, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - scale = jnp.empty((), jnp.float32) + scale = jnp.empty((1,), jnp.float32) amax = None if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. @@ -771,7 +778,7 @@ def _quantize_dbias_impl( amax = AmaxCalculationPrimitive.outer_primitive.bind( x.data, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: @@ -845,7 +852,7 @@ def quantize( quantizer: Quantizer, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -866,7 +873,7 @@ def quantize( quantizer=quantizer, flatten_axis=flatten_axis, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) return out @@ -877,7 +884,7 @@ def quantize_dbias( is_dbias: bool = True, flatten_axis: int = -1, amax_scope: AmaxScope = AmaxScope.LOCAL, - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -904,7 +911,7 @@ def quantize_dbias( is_dbias=is_dbias, flatten_axis=flatten_axis, amax_scope=amax_scope, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 0ecf79150..f512321c3 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -15,13 +15,14 @@ namespace transformer_engine { namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, + Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int, ActivationConfig act_params) { + Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, + bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -30,7 +31,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto *output = output_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data(); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto input_dims = input_buf.dimensions(); auto m = product(input_dims, 0, input_dims.size() - 2); @@ -45,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_trans_shape = std::vector{static_cast(n), m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } NVTE_CHECK( scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, @@ -55,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); @@ -145,26 +150,29 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") .Attr("is_2x") - .Attr("act_params"), + .Attr("act_params") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, - JAXX_Scaling_Mode scaling_mode, bool is_2x_int, - ActivationConfig act_params) { - return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, output_buf, - colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, - act_enum, scaling_mode, is_2x_int, act_params); + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, + ActivationConfig act_params, bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, + output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, + updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params, + output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, @@ -172,15 +180,17 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") .Attr("is_2x") - .Attr("act_params")); + .Attr("act_params") + .Attr("output_amax_when_no_scaling")); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, @@ -246,15 +256,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias, - ActivationConfig act_params) { + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, + bool is_dbias, ActivationConfig act_params, + bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -262,7 +274,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto *input = input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data(); float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -305,13 +319,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); @@ -440,6 +455,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Arg() // input .Arg() // act input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -451,19 +467,22 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Attr("act_enum") .Attr("is_2x") .Attr("is_dbias") - .Attr("act_params"), + .Attr("act_params") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); Error_Type DActLuDBiasQuantizeInitializeFFI( cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, - bool is_dbias, ActivationConfig act_params) { + Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params, + bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, - act_input_buf, scale_buf, output_buf, colwise_output_buf, - scale_inv_buf, colwise_scale_inv_buf, amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params); + act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, + scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, + workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params, + output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -473,18 +492,20 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Arg() // input .Arg() // act input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") .Attr("is_2x") .Attr("is_dbias") - .Attr("act_params")); + .Attr("act_params") + .Attr("output_amax_when_no_scaling")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 523819392..378e009c8 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); + output_tensor.set_amax(nullptr, DType::kFloat32, std::vector{1}); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { @@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si } Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, - Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, - int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { + Buffer_Type amax_buf, Buffer_Type gamma_buf, Buffer_Type beta_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, + double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, + bool is_2x, bool output_amax_when_no_scaling) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -77,9 +79,12 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *output = output_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data(); auto *mu = mu_buf->untyped_data(); - auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); + auto _norm_type = static_cast(norm_type); auto _is_2x = static_cast(is_2x); @@ -106,6 +111,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } NVTE_CHECK( scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, @@ -123,8 +132,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - nvte_memset(amax, 0, sizeof(float), stream); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); } if (_is_2x) { @@ -162,13 +169,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Ctx() // stream .Arg() // x .Arg() // scale + .Arg() // amax .Arg() // gamma .Arg() // beta .Ret() // output .Ret() // colwise_output .Ret() // scale_inv .Ret() // colwise_scale_inv - .Ret() // amax + .Ret() // updated_amax .Ret() // mu .Ret() // rsigma .Ret() // wkspace @@ -177,21 +185,25 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("is_2x") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type gamma_buf, Buffer_Type beta_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type mu_buf, - Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, + Buffer_Type amax_buf, Buffer_Type gamma_buf, + Buffer_Type beta_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, int64_t sm_margin, - JAXX_Scaling_Mode scaling_mode, bool is_2x) { - return wrapInStreamCapture( - std::function(NormForwardFFI), stream, x_buf, scale_buf, gamma_buf, beta_buf, output_buf, - colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, amax_buf, mu_buf, rsigma_buf, - wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, scaling_mode, is_2x); + JAXX_Scaling_Mode scaling_mode, bool is_2x, + bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, + gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, + colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, + wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, + scaling_mode, is_2x, output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, @@ -199,13 +211,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ .Ctx() // stream .Arg() // x .Arg() // scale + .Arg() // amax .Arg() // gamma .Arg() // beta .Ret() // output .Ret() // colwise_output .Ret() // scale_inv .Ret() // colwise_scale_inv - .Ret() // amax + .Ret() // updated_amax .Ret() // mu .Ret() // rsigma .Ret() // wkspace @@ -214,7 +227,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x")); + .Attr("is_2x") + .Attr("output_amax_when_no_scaling")); pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, NVTE_Norm_Type norm_type, diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index d17d83ec1..05260741b 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -120,9 +120,11 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (is_fp8_dtype(out_dtype)) { if (is_tensor_scaling) { float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + float *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + NVTE_CHECK(amax == updated_amax && amax != nullptr, + "amax must be provided for delayed tensor scaling"); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 3cdf6ba7a..28525a22a 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -63,7 +63,7 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, output_axes: Tuple[str, ...] = None, @@ -81,7 +81,7 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract - batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix output_axes: Logical axes for sharding the output @@ -91,8 +91,8 @@ def dense( Returns: Transformed output tensor """ - if batch_sequence_transpose: - warnings.warn("batch_sequence_transpose is not well tested, use with caution!") + if transpose_batch_sequence: + warnings.warn("transpose_batch_sequence is not well tested, use with caution!") if not get_quantize_config().is_fp8_enabled(): input_dtype = x.dtype @@ -103,7 +103,7 @@ def dense( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -119,7 +119,7 @@ def _dense( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -136,7 +136,7 @@ def _dense( kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification - batch_sequence_transpose: Transpose the batch and sequence dimensions of the input tensor. + transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix @@ -151,7 +151,7 @@ def _dense( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -166,7 +166,7 @@ def _dense_fwd_rule( kernel, bias, contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -197,7 +197,7 @@ def _dense_fwd_rule( flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, amax_scope=AmaxScope.TPSP, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -215,7 +215,7 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set.forward, @@ -240,7 +240,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( contracting_dims, - batch_sequence_transpose, + transpose_batch_sequence, input_axes, kernel_axes, output_axes, @@ -274,7 +274,7 @@ def _dense_bwd_rule( flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, - batch_sequence_transpose=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) # GEMM NT @@ -291,7 +291,7 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op_set.backward, ) @@ -305,7 +305,7 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index f02876d8f..76865f7c1 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -432,6 +432,8 @@ class DenseGeneral(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. """ features: Union[Iterable[int], int] @@ -446,6 +448,7 @@ class DenseGeneral(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 input_axes: Tuple[str, ...] = () + transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: @@ -512,6 +515,7 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -632,6 +636,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float, default = None The factor to scale the output from `DenseGeneral`. It should be a float value or None. When None is set, then no scaling is applied. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. """ features: Union[Iterable[int], int] @@ -657,6 +663,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None depth_scaling: float = None + transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: @@ -768,6 +775,7 @@ def __call__(self, inputs: Array) -> Array: dot_input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -775,6 +783,7 @@ def __call__(self, inputs: Array) -> Array: y, kernel, contracting_dims=(axis, contract_ind), + transpose_batch_sequence=self.transpose_batch_sequence, input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, @@ -940,6 +949,8 @@ class LayerNormMLP(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. """ intermediate_dim: int = 2048 @@ -974,6 +985,7 @@ class LayerNormMLP(TransformerEngineBase): dot_2_input_axes: Tuple[str, ...] = None ffn1_ckpt_name: str = "ffn1" ffn2_ckpt_name: str = "ffn2" + transpose_batch_sequence: bool = False def __post_init__(self): if self.kernel_init is None: @@ -1160,6 +1172,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): activation_type=normalized_acts, activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), + transpose_batch_sequence=self.transpose_batch_sequence, ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1178,6 +1191,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) @@ -1188,6 +1202,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -1260,6 +1275,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_2_input_axes, kernel_axes=self.kernel_axes_2, quantizer_set=ffn2_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 868bcfa05..c95765bf3 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1207,6 +1207,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="qkv", dtype=self.dtype, )(inputs_q) @@ -1234,6 +1235,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="query", )(inputs_q) @@ -1252,6 +1254,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): enable_low_rank_adaptation=lora_scope.qkv_proj, low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, + transpose_batch_sequence=self.transpose_batch_sequence, name="kv", dtype=self.dtype, )(inputs_kv) @@ -1292,6 +1295,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="query", )(inputs_q) @@ -2070,6 +2074,7 @@ def hidden_dropout(x, deterministic): layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), + transpose_batch_sequence=self.transpose_batch_sequence, name="mlp", )(mlp_input, deterministic=deterministic) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index fb9783075..136f43df4 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -16,6 +16,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex +from .cpp_extensions.quantization import AmaxScope from .quantize import ( QuantizerSet, @@ -35,6 +36,7 @@ def layernorm_dense( norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, + transpose_batch_sequence: bool = False, layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, @@ -55,6 +57,7 @@ def layernorm_dense( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix @@ -83,6 +86,7 @@ def layernorm_dense( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -100,6 +104,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -111,6 +116,7 @@ def _layernorm_dense( norm_type: str, zero_centered_gamma: bool, epsilon: float, + transpose_batch_sequence: bool, layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], @@ -131,6 +137,7 @@ def _layernorm_dense( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding quantizer_set: Set of quantizers @@ -147,6 +154,7 @@ def _layernorm_dense( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -164,6 +172,7 @@ def _layernorm_dense_fwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -194,6 +203,8 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) @@ -203,6 +214,8 @@ def _layernorm_dense_fwd_rule( kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -213,6 +226,7 @@ def _layernorm_dense_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -245,6 +259,7 @@ def _layernorm_dense_bwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -285,6 +300,8 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -301,6 +318,7 @@ def _layernorm_dense_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -314,6 +332,7 @@ def _layernorm_dense_bwd_rule( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_constracting_dim, g_constracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 77daa4672..c43430cf3 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -41,7 +41,7 @@ def layernorm_mlp( norm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6, - batch_sequence_transpose: bool = False, + transpose_batch_sequence: bool = False, norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] = None, @@ -78,7 +78,7 @@ def layernorm_mlp( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization - batch_sequence_transpose: Whether to transpose the batch and sequence dimensions + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication @@ -130,7 +130,7 @@ def layernorm_mlp( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -158,7 +158,7 @@ def _layernorm_mlp( norm_type: str, zero_centered_gamma: bool, epsilon: float, - batch_sequence_transpose: bool, + transpose_batch_sequence: bool, norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], @@ -188,7 +188,7 @@ def _layernorm_mlp( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability - batch_sequence_transpose: Whether to transpose the batch and sequence dimensions + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for layernorm sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding @@ -214,7 +214,7 @@ def _layernorm_mlp( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -241,7 +241,7 @@ def _layernorm_mlp_fwd_rule( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -302,11 +302,16 @@ def _layernorm_mlp_fwd_rule( norm_type, quantizer=ffn1_quantizer_set.x, amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, amax_scope=AmaxScope.FSDP + kernel_1, + flatten_axis=-2, + quantizer=ffn1_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) # NN GEMM @@ -315,7 +320,7 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set_1.forward, @@ -345,6 +350,8 @@ def _layernorm_mlp_fwd_rule( if activation_params else None ), + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -353,6 +360,7 @@ def _layernorm_mlp_fwd_rule( kernel_2, quantizer=ffn2_quantizer_set.kernel, amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) # NN GEMM @@ -361,7 +369,7 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, collective_op=collective_op_set_2.forward, @@ -403,7 +411,7 @@ def _layernorm_mlp_bwd_rule( norm_type, zero_centered_gamma, epsilon, - batch_sequence_transpose, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -465,6 +473,7 @@ def _layernorm_mlp_bwd_rule( is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -482,7 +491,7 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op_set_2.backward, ) @@ -498,7 +507,7 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -513,6 +522,8 @@ def _layernorm_mlp_bwd_rule( if activation_params else None ), + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -530,7 +541,7 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, collective_op=collective_op_set_1.backward, ) @@ -542,7 +553,7 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - transpose_batch_sequence=batch_sequence_transpose, + transpose_batch_sequence=transpose_batch_sequence, ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) From 76bced540eb264a194b8cd28f8894f860d841e6a Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 7 Oct 2025 12:17:24 -0700 Subject: [PATCH 038/141] `NVFP4BlockScaling` recipe docs (#2241) * Improve docstring for NVFP4 recipe Signed-off-by: Kirthi Shankar Sivamani * Add NVFP4BlockScaling to recipe docs Signed-off-by: Kirthi Shankar Sivamani * Grammar Signed-off-by: Kirthi Shankar Sivamani * improve wording Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Przemyslaw Tredak Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/common/recipe/__init__.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Przemyslaw Tredak Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- docs/api/common.rst | 2 ++ transformer_engine/common/recipe/__init__.py | 28 +++++++++++++++----- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/docs/api/common.rst b/docs/api/common.rst index 541118985..3edd7cae2 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -12,6 +12,8 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) +.. autoapiclass:: transformer_engine.common.recipe.NVFP4BlockScaling(fp4_format=Format.E2M1) + .. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID) .. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1a9b02987..1204c37c5 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -401,16 +401,32 @@ class NVFP4BlockScaling(Recipe): computed from the high precision input to avoid double quantization errors. + The default NVFP4 training recipe implements 3 techniques for quantizing + to a narrow format (4-bit): + + - For weight tensors a variant of the NVFP4 quantization is used, + where a single scaling factor is shared by a 2D block of 16x16 elements. + - When quantizing gradients, stochastic rounding is applied to avoid the bias + introduced by quantization. With this, values are rounded probabilistically + to one of their two nearest representable numbers, with probabilities + inversely proportional to their distances. + - When quantizing inputs and gradients, random Hadamard transforms are applied + (16x16 Hadamard matrix) to smooth outliers in the tensor distributions + and make them easier to represent accurately in NVFP4. + + These techniques are described more comprehensively in the NVFP4 paper titled + 'Pretraining Large Language Models with NVFP4' (https://arxiv.org/abs/2509.25149v1). + Parameters ---------- fp4_format : {Format.E2M1}, default = Format.E2M1 FP4 data type. - fp8_format : {Format.E4M3}, default = Format.E4M3 - FP8 data type. Only E4M3 is supported. - fp8_dpa: bool, default = `False` - FP8 dot product attention. Not yet supported. - fp8_mha: bool, default = `False` - FP8 multi-head attention. Not yet supported. + disable_rht : bool, default = `False` + If set to `True`, random Hadamard transforms are not applied to any tensor. + disable_stochastic_rounding : bool, default = `False` + If set to `True`, stochastic rounding is disabled during quantization for all tensors. + disable_2d_quantization : bool, default = `False` + If set to `True`, 1D block scaling with block size 16 is used for all tensors. """ # Configuration envvars From ac5e868f143401f04664b8cb8f39d806ac912078 Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Tue, 7 Oct 2025 12:51:54 -0700 Subject: [PATCH 039/141] Skip fp8 tests on unsupported devices (#2243) Signed-off-by: Vladimir Cherepanov --- tests/cpp_distributed/test_comm_gemm.cu | 31 +++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu index 8355d5f96..884faa474 100644 --- a/tests/cpp_distributed/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) { return supported; } +int GetDeviceComputeCapability(int device_id) { + int major{}; + int minor{}; + CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id)); + CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id)); + return major * 10 + minor; +} + +template +bool IsDTypeSupported(int /* device_id */) { + return true; +} + +template <> +bool IsDTypeSupported(int device_id) { + return GetDeviceComputeCapability(device_id) >= 89; +} + +template <> +bool IsDTypeSupported(int device_id) { + return GetDeviceComputeCapability(device_id) >= 89; +} + +template +bool AllDTypesSupported(int device_id) { + return (IsDTypeSupported(device_id) && ...); +} + template std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, size_t nsize, size_t ld) { @@ -161,6 +189,9 @@ class CommGemmFixure : public ::testing::TestWithParam { template void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + if (!AllDTypesSupported(rank_)) + GTEST_SKIP() << "FP8 is not supported on device " << rank_; + cudaStream_t stream{}; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); From 66f9b3cbae214d521ac18883fe9a386b8893b179 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 7 Oct 2025 21:32:47 -0700 Subject: [PATCH 040/141] [PyTorch] Unblock fused bgrad quantization path for nvfp4 (#2246) Unblock path for fusing NVFP4 quantize and bgrad Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/base.py | 4 +--- transformer_engine/pytorch/module/layernorm_mlp.py | 5 +---- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3ae389568..838ac5281 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -40,7 +40,6 @@ from ..constants import dist_group_type from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.storage.float8_tensor_storage import Float8TensorStorage @@ -1229,8 +1228,7 @@ def grad_output_preprocess( ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: - # TODO(ksivaman): Re-add fusion once kernel is available. - if isinstance(quantizer, (Float8BlockQuantizer, NVFP4Quantizer)): + if isinstance(quantizer, Float8BlockQuantizer): # unfuse bgrad for now until cast_transpose + dgrad calculation is ready for Float8BlockQuantizer. grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2097f01b1..d680a9f8f 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1037,11 +1037,8 @@ def fc2_wgrad_gemm( if ctx.fp8: # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now - # TODO(ksivaman): Re-add fusion once kernel is available. if ( - isinstance( - ctx.fc1_grad_output_quantizer, (Float8BlockQuantizer, NVFP4Quantizer) - ) + isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) or ctx.fp8_recipe.custom() ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) From af2a0c16ec11363c0af84690cd877a59f898820e Mon Sep 17 00:00:00 2001 From: Hua Huang Date: Wed, 8 Oct 2025 08:30:56 -0700 Subject: [PATCH 041/141] [JAX] Async issuing D2H memcpy for grouped_gemm group_sizes array (#2213) * Try async copy of grouped GEMM group_sizes data Signed-off-by: Hua Huang --------- Signed-off-by: Hua Huang Co-authored-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 10 ++- transformer_engine/jax/cpp_extensions/gemm.py | 87 ++++++++++++++++++- transformer_engine/jax/csrc/extensions.h | 1 + .../jax/csrc/extensions/gemm.cpp | 81 +++++++++++++++-- .../jax/csrc/extensions/pybind.cpp | 3 + 5 files changed, 172 insertions(+), 10 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 7a4fa268a..124e0248b 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1366,14 +1366,22 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) + num_gemms = input_shape[0] + _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( + group_sizes, + num_gemms=num_gemms, + ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + prim_out = jax.jit( + tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") + )( lhs, rhs, group_sizes, contracting_dims, + use_async_d2h_group_sizes=True, ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 865efe89d..7fe433bcc 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -58,6 +58,7 @@ "collective_gemm_bootstrap", "noop_collective_op_set", "gemm", + "grouped_gemm_copy_group_sizes", "grouped_gemm", "gemm_uses_jax_dot", "sanitize_dims", @@ -1237,6 +1238,63 @@ def _te_gemm( ) +class GroupedGemmCopySizesPrimitive(BasePrimitive): + """ + Primitive for async copying group sizes from device to host + """ + + name = "te_grouped_gemm_d2h_group_sizes_ffi" + multiple_results = False + impl_static_args = (1,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + group_sizes_aval, + *, + num_gemms, + ): + del num_gemms + out_aval = group_sizes_aval + return out_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + out = GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs) + return out + + @staticmethod + def lowering( + ctx, + group_sizes, + num_gemms, + ): + return jax.ffi.ffi_lowering( + GroupedGemmCopySizesPrimitive.name, + operand_output_aliases={0: 0}, # Mark num_gemms as the output + )( + ctx, + group_sizes, + num_gemms=num_gemms, + ) + + @staticmethod + def impl( + group_sizes, + num_gemms, + ): + assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + +register_primitive(GroupedGemmCopySizesPrimitive) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -1244,7 +1302,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -1267,6 +1325,7 @@ def abstract( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): """ Grouped GEMM operation. @@ -1294,7 +1353,7 @@ def abstract( A jnp.ndarray containing the result of the grouped GEMM operation """ del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval - del K, lhs_is_trans, rhs_is_trans, has_bias + del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_alignment_padding = 256 @@ -1341,6 +1400,7 @@ def lowering( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( @@ -1354,6 +1414,7 @@ def lowering( scaling_mode=scaling_mode.value, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @staticmethod @@ -1374,6 +1435,7 @@ def impl( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): assert GroupedGemmPrimitive.inner_primitive is not None (out, _) = GroupedGemmPrimitive.inner_primitive.bind( @@ -1393,6 +1455,7 @@ def impl( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return (out,) @@ -1661,6 +1724,24 @@ def gemm( return clean_outputs +def grouped_gemm_copy_group_sizes( + group_sizes: jnp.ndarray, + num_gemms: int, +) -> jnp.ndarray: + """ + Async copy group sizes from device to host + + Args: + group_sizes: 1D array containing the sizes of each group + num_gemms: number of grouped gemm calls to be made + """ + out = GroupedGemmCopySizesPrimitive.outer_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -1671,6 +1752,7 @@ def grouped_gemm( preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + use_async_d2h_group_sizes: bool = False, ) -> jnp.ndarray: """ Grouped GEMM operation. @@ -1854,5 +1936,6 @@ def grouped_gemm( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return out diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index bbfc62120..3ce6dee73 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -135,6 +135,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); // Cudnn helpers diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f2007efcf..993ec1377 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -284,12 +284,71 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("collective_op"), FFI_CudaGraph_Traits); +size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, + int32_t *host_group_sizes) { + static std::once_flag init_flag; + static cudaEvent_t d2h_event; + static size_t host_num_gemms; + static const size_t max_num_gemms = 1024; + //static int32_t host_group_sizes_internal[max_num_gemms]; + static int32_t *host_group_sizes_internal = nullptr; + auto init = [&]() { + NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event)); + NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms)); + }; + std::call_once(init_flag, init); + + NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr, + "Only one of dev_group_sizes and host_group_sizes can be non-nullptr."); + + if (dev_group_sizes != nullptr) { + NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ", + "supported number ", max_num_gemms, " to be downloaded in advance."); + host_num_gemms = num_gemms; + // Wait for current compute stream to finish + cudaStream_t compute_stream_0 = nvte_get_compute_stream(0); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event)); + // Async copy group_sizes from device to host + size_t copy_bytes = sizeof(int32_t) * num_gemms; + NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes, + cudaMemcpyDeviceToHost, compute_stream_0)); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0)); + return num_gemms; + } + + if (host_group_sizes != nullptr) { + if (host_num_gemms == 0) return 0; + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the previous value ", host_num_gemms, "."); + // Wait for the async copy to finish, then copy group_sizes to user buffer + // Note: This may break cudaGraph. + NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event)); + memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms); + return host_num_gemms; + } +} + +Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes, + Result_Type dummy_output, size_t num_gemms) { + int32_t *dev_group_sizes = reinterpret_cast(group_sizes.untyped_data()); + GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // group_sizes + .Ret() // dummy_output + .Attr("num_gemms")); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad) { + bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -410,11 +469,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); if (!is_grouped_dense_wgrad) { NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, @@ -673,7 +739,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad")); + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 23d46b338..f6b1acd43 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -69,6 +69,9 @@ pybind11::dict Registrations() { pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM + dict["te_grouped_gemm_d2h_group_sizes_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler)); dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); From e37e33e12768a9fa397b51cd17c6425775c543ea Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 8 Oct 2025 19:19:51 -0700 Subject: [PATCH 042/141] Disallow pure E5M2 recipe for `Float8BlockScaling` (#2251) Catch unsupported GEMM during recipe init Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/recipe/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 1204c37c5..f70b43a7a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -363,6 +363,7 @@ def __post_init__(self) -> None: assert ( not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: return ( From 9bf4175f6b100219e0e02f4ca50d9d8fa5331efe Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 8 Oct 2025 19:20:33 -0700 Subject: [PATCH 043/141] [PyTorch] Deprecate old `float8_tensor.py` (#2250) Deprecate old float8_tensor.py Signed-off-by: Kirthi Shankar Sivamani --- .../attention/dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 2 +- .../dot_product_attention/dot_product_attention.py | 2 +- .../pytorch/attention/dot_product_attention/utils.py | 2 +- .../pytorch/attention/multi_head_attention.py | 2 +- transformer_engine/pytorch/float8_tensor.py | 10 ++++++++++ transformer_engine/pytorch/utils.py | 6 +++--- 7 files changed, 18 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 3a1375838..0ddb261d2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -29,7 +29,7 @@ prepare_for_saving, restore_from_saved, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d0ddae25e..d1374e949 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -20,7 +20,7 @@ FusedAttnBackend, ) from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 88e28e3d8..df96067d6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -30,7 +30,7 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ea7b0e876..8b26a1760 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -35,8 +35,8 @@ META_DP, ) from transformer_engine.pytorch.attention.inference import InferenceParams -from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer, ) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index b2f1ff1ac..8f0183224 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -10,7 +10,7 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index a771e3bb7..eeafc23c7 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -4,6 +4,16 @@ """Tensor class with FP8 data""" +import warnings + from .tensor.float8_tensor import Float8Tensor +warnings.warn( + "transformer_engine.pytorch.float8_tensor is deprecated and will be removed" + " in a future release. Float8Tensor should be imported directly through " + "`from transformer_engine.pytorch import Float8Tensor`", + DeprecationWarning, + stacklevel=2, +) + __all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 8ea362371..b1a7e3731 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -184,7 +184,7 @@ def combine_tensors( num_tensors = len(tensors) new_shape = list(tensors[0].shape) new_shape.insert(dim, num_tensors) - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor if isinstance(tensors[0], Float8Tensor): new_stride = list(tensors[0]._data.stride()) @@ -224,7 +224,7 @@ def forward( # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( Float8TensorStorage, ) @@ -278,7 +278,7 @@ def backward(ctx, *grad_outputs): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) split_dim = (ctx.split_dim + dims) % dims - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor if isinstance(grad_outputs[0], Float8Tensor): noop_ok = True From e99be1b6af1aa138194c211ad9952858b3aaee44 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 9 Oct 2025 10:17:41 -0700 Subject: [PATCH 044/141] Update minimum python version to 3.10 and add checks in CI (#2247) * Update minimum python version to 3.10 and update CI Signed-off-by: Kirthi Shankar Sivamani * review Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .pre-commit-config.yaml | 6 ++++++ build_tools/utils.py | 19 +++++++++++++++++++ setup.py | 3 ++- transformer_engine/jax/setup.py | 3 ++- transformer_engine/pytorch/setup.py | 3 ++- 5 files changed, 31 insertions(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbe486fac..5043d6ea2 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,3 +38,9 @@ repos: entry: clang-format -i args: ["-style=file"] files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$ + + - repo: https://github.com/netromdk/vermin + rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3 + hooks: + - id: vermin + args: ['-t=3.10', '--violations'] diff --git a/build_tools/utils.py b/build_tools/utils.py index 3d8ec462c..296f928b7 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -12,12 +12,31 @@ import shutil import subprocess import sys +import platform from pathlib import Path from importlib.metadata import version as get_version from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union +# Needs to stay consistent with .pre-commit-config.yaml config. +def min_python_version() -> Tuple[int]: + """Minimum supported Python version.""" + return (3, 10, 0) + + +def min_python_version_str() -> str: + """String representing minimum supported Python version.""" + return ".".join(map(str, min_python_version())) + + +if sys.version_info < min_python_version(): + raise RuntimeError( + f"Transformer Engine requires Python {min_python_version_str()} or newer, " + f"but found Python {platform.python_version()}." + ) + + @functools.lru_cache(maxsize=None) def debug_build_enabled() -> bool: """Whether to build with a debug configuration""" diff --git a/setup.py b/setup.py index ed1f5b8a9..c932da5e0 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ cuda_version, get_frameworks, remove_dups, + min_python_version_str, ) frameworks = get_frameworks() @@ -190,7 +191,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", + python_requires=f">={min_python_version_str()}", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index ca83cf465..f83375d82 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -44,7 +44,7 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers +from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements @@ -100,6 +100,7 @@ description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, + python_requires=f">={min_python_version_str()}", install_requires=install_requirements(), tests_require=test_requirements(), ) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 46543acf2..08870040f 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -45,7 +45,7 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import copy_common_headers +from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.pytorch import ( setup_pytorch_extension, @@ -152,6 +152,7 @@ def run(self): description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, + python_requires=f">={min_python_version_str()}", install_requires=install_requirements(), tests_require=test_requirements(), ) From 8a7ab3ddc17e275fcbcd2eee8688ada265efbcad Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 9 Oct 2025 13:49:12 -0700 Subject: [PATCH 045/141] [JAX] NVFP4 support in TE/JAX (#2254) Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- examples/jax/encoder/common.py | 13 +- .../run_test_multiprocessing_encoder.sh | 2 + .../encoder/test_model_parallel_encoder.py | 118 +++-- examples/jax/encoder/test_multigpu_encoder.py | 84 ++- .../encoder/test_multiprocessing_encoder.py | 80 ++- .../jax/encoder/test_single_gpu_encoder.py | 57 +- examples/jax/mnist/test_single_gpu_mnist.py | 10 +- tests/jax/test_custom_call_compute.py | 495 +++++++++++++++--- tests/jax/test_distributed_layernorm_mlp.py | 133 +++-- tests/jax/test_helper.py | 69 ++- tests/jax/utils.py | 6 + transformer_engine/jax/__init__.py | 3 +- .../jax/cpp_extensions/__init__.py | 1 + .../jax/cpp_extensions/activation.py | 10 +- transformer_engine/jax/cpp_extensions/amax.py | 420 +++++++++++++++ transformer_engine/jax/cpp_extensions/gemm.py | 134 ++++- transformer_engine/jax/cpp_extensions/misc.py | 15 +- .../jax/cpp_extensions/normalization.py | 15 +- .../jax/cpp_extensions/quantization.py | 414 ++++++++------- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/amax.cpp | 100 ++++ .../jax/csrc/extensions/ffi.cpp | 3 + transformer_engine/jax/csrc/extensions/ffi.h | 2 + .../jax/csrc/extensions/gemm.cpp | 56 +- .../jax/csrc/extensions/misc.cpp | 20 +- transformer_engine/jax/csrc/extensions/misc.h | 31 +- .../jax/csrc/extensions/pybind.cpp | 11 +- .../jax/csrc/extensions/quantization.cpp | 189 +++++-- transformer_engine/jax/dense.py | 2 +- transformer_engine/jax/flax/module.py | 60 +-- transformer_engine/jax/layernorm_dense.py | 2 +- transformer_engine/jax/layernorm_mlp.py | 2 +- transformer_engine/jax/quantize/__init__.py | 1 + .../jax/quantize/dequantizer.py | 97 +++- transformer_engine/jax/quantize/hadamard.py | 72 +++ transformer_engine/jax/quantize/helper.py | 376 ++++++++++--- transformer_engine/jax/quantize/metadata.py | 16 +- transformer_engine/jax/quantize/quantizer.py | 278 +++++++++- .../jax/quantize/scaling_modes.py | 267 +++++++++- transformer_engine/jax/quantize/tensor.py | 38 +- 40 files changed, 2987 insertions(+), 721 deletions(-) create mode 100644 transformer_engine/jax/cpp_extensions/amax.py create mode 100644 transformer_engine/jax/csrc/extensions/amax.cpp create mode 100644 transformer_engine/jax/quantize/hadamard.py diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index a8bf25113..772d5f4c1 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -33,6 +33,13 @@ def is_mxfp8_supported(): return gpu_arch >= 100 +@lru_cache +def is_nvfp4_supported(): + """Return if FP8 has hardware supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 100 + + def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False): """Checks whether most params are sharded across sharding axis. @@ -98,7 +105,7 @@ def assert_leaf_sharding(path, arr): ) -def get_fp8_recipe_from_name_string(name: str): +def get_quantization_recipe_from_name_string(name: str): """Query recipe from a given name string""" match name: case "DelayedScaling": @@ -107,5 +114,7 @@ def get_fp8_recipe_from_name_string(name: str): return recipe.MXFP8BlockScaling() case "Float8CurrentScaling": return recipe.Float8CurrentScaling() + case "NVFP4BlockScaling": + return recipe.NVFP4BlockScaling() case _: - raise ValueError(f"Invalid fp8_recipe, got {name}") + raise ValueError(f"Invalid quantization_recipe, got {name}") diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 2a979e177..fa7102cb4 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -10,9 +10,11 @@ TEST_CASES=( "test_te_delayed_scaling_fp8" "test_te_current_scaling_fp8" "test_te_mxfp8" +"test_te_nvfp4" "test_te_bf16_shardy" "test_te_delayed_scaling_fp8_shardy" "test_te_current_scaling_fp8_shardy" +"test_te_nvfp4_shardy" ) : ${TE_PATH:=/opt/transformerengine} diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 41832650f..5fc7efbba 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -21,13 +21,13 @@ from common import ( is_bf16_supported, - get_fp8_recipe_from_name_string, + get_quantization_recipe_from_name_string, assert_params_sufficiently_sharded, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DEVICE_DP_AXIS = "data" @@ -36,6 +36,7 @@ NAMED_TP_AXIS = "my_tp_axis" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +SR_KEY = "sr_rng" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -121,6 +122,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -135,11 +138,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -150,7 +153,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn): +def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -159,11 +162,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -223,7 +228,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -257,7 +262,7 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -275,7 +280,8 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] @@ -355,7 +361,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -367,22 +380,24 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs ) print( @@ -402,16 +417,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for training (default: 128)", + help="input batch size for training (default: 256)", ) parser.add_argument( "--test-batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for testing (default: 128)", + help="input batch size for testing (default: 256)", ) parser.add_argument( "--max-seq-len", @@ -466,8 +481,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 5 epochs for testing""" @@ -477,7 +493,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -485,7 +501,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.361 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -493,14 +509,22 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -509,7 +533,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -518,14 +542,23 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_with_sp(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -534,7 +567,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp_shardy(self): @@ -544,24 +577,27 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.361 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_with_sp_shardy(self): """Test Transformer Engine with MXFP8 + SP""" self.args.enable_shardy = True @@ -569,7 +605,17 @@ def test_te_mxfp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_with_sp_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index bc6a56752..68fb3ddd3 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -19,17 +19,18 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding -from common import is_bf16_supported, get_fp8_recipe_from_name_string +from common import is_bf16_supported, get_quantization_recipe_from_name_string import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" DROPOUT_KEY = "dropout" +SR_KEY = "sr_rng" INPUT_KEY = "input_rng" @@ -97,6 +98,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -111,11 +114,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -126,7 +129,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn): +def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -135,11 +138,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -199,7 +204,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -254,7 +259,7 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -270,6 +275,7 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) + rng, sr_rng = jax.random.split(rng) init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} input_shape = [args.batch_size, args.max_seq_len] @@ -322,7 +328,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -334,22 +347,24 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs ) print( @@ -369,16 +384,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=256, + default=512, metavar="N", - help="input batch size for training (default: 256)", + help="input batch size for training (default: 512)", ) parser.add_argument( "--test-batch-size", type=int, - default=256, + default=512, metavar="N", - help="input batch size for testing (default: 256)", + help="input batch size for testing (default: 512)", ) parser.add_argument( "--max-seq-len", @@ -430,8 +445,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 5 epochs for testing""" @@ -441,7 +457,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -449,7 +465,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -457,7 +473,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.749 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -465,6 +481,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) + assert actual[0] < 0.51 and actual[1] > 0.75 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @@ -472,7 +496,7 @@ def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -481,7 +505,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8_shardy(self): @@ -490,18 +514,24 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.749 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) + assert actual[0] < 0.51 and actual[1] > 0.75 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index abf6a407b..358fbca4b 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -25,7 +25,8 @@ is_bf16_supported, is_fp8_supported, is_mxfp8_supported, - get_fp8_recipe_from_name_string, + is_nvfp4_supported, + get_quantization_recipe_from_name_string, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex @@ -39,6 +40,7 @@ NAMED_TP_AXIS = "my_tp_axis" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +SR_KEY = "sr_rng" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -175,6 +177,8 @@ def train_epoch( epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_input = sentence[perm, ...] batch_mask = mask[perm, ...] batch_label = label[perm, ...] @@ -200,11 +204,11 @@ def train_epoch( return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels, 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -216,7 +220,16 @@ def loss_fn(var_collect, disable_dropout=False): def eval_model( - state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec + state, + test_ds, + batch_size, + var_collect, + eval_fn, + mesh, + inputs_pspec, + masks_pspec, + labels_pspec, + rngs, ): """Evaluation loop.""" global_input_shape, input_named_sharding, sentence = shard_array_wrapper( @@ -233,7 +246,8 @@ def eval_model( all_accuracy = [] for batch_input, batch_mask, batch_label in zip(sentence, mask, label): - + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} shard_input = jax.make_array_from_single_device_arrays( global_input_shape, input_named_sharding, [batch_input] ) @@ -244,7 +258,7 @@ def eval_model( global_label_shape, label_named_sharding, [batch_label] ) - loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect) + loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -303,7 +317,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -372,7 +386,7 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -390,7 +404,8 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] @@ -444,7 +459,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -456,14 +478,16 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") else: for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, @@ -488,6 +512,7 @@ def train_and_evaluate(args): inputs_pspec, masks_pspec, labels_sharding.spec, + rngs, ) if args.process_id == 0: print( @@ -508,16 +533,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for training (default: 128)", + help="input batch size for training (default: 256)", ) parser.add_argument( "--test-batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for testing (default: 128)", + help="input batch size for testing (default: 256)", ) parser.add_argument( "--max-seq-len", @@ -629,7 +654,7 @@ def test_te_delayed_scaling_fp8(self): def test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.43 and result[1] > 0.80 + assert result[0] < 0.432 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -639,6 +664,14 @@ def test_te_mxfp8(self): result = self.exec(True, "MXFP8BlockScaling") assert result[0] < 0.43 and result[1] > 0.80 + @unittest.skipIf( + not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" + ) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + result = self.exec(True, "NVFP4BlockScaling") + assert result[0] < 0.451 and result[1] > 0.79 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" @@ -659,19 +692,24 @@ def test_te_delayed_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 + assert result[0] < 0.432 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" ) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) assert result[0] < 0.43 and result[1] > 0.80 + @unittest.skipIf( + not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" + ) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) + assert result[0] < 0.451 and result[1] > 0.79 + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 826d0d2fc..320483099 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -16,14 +16,15 @@ from flax import linen as nn from flax.training import train_state -from common import is_bf16_supported, get_fp8_recipe_from_name_string +from common import is_bf16_supported, get_quantization_recipe_from_name_string import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode PARAMS_KEY = "params" DROPOUT_KEY = "dropout" +SR_KEY = "sr_rng" INPUT_KEY = "input_rng" @@ -92,6 +93,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -107,11 +110,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): @jax.jit -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -122,7 +125,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect): +def eval_model(state, test_ds, batch_size, var_collect, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -131,11 +134,15 @@ def eval_model(state, test_ds, batch_size, var_collect): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_step( + state, batch_inputs, batch_masks, batch_labels, var_collect, rngs + ) all_loss.append(loss) all_accuracy.append(accuracy) @@ -195,7 +202,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -208,14 +215,15 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -238,21 +246,25 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect ) - test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, rngs + ) print( f"Epoch: {epoch:>2} " @@ -329,8 +341,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 3 epochs for testing""" @@ -340,7 +353,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.452 and actual[1] > 0.788 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -348,7 +361,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.457 and actual[1] > 0.784 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -356,7 +369,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.461 and actual[1] > 0.784 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -364,7 +377,15 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.457 and actual[1] > 0.784 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.476 and actual[1] > 0.775 if __name__ == "__main__": diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 92baf4b0c..81bea4a32 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -18,11 +18,11 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DIR = str(Path(__file__).resolve().parents[1]) sys.path.append(str(DIR)) -from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string +from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string IMAGE_H = 28 IMAGE_W = 28 @@ -189,7 +189,7 @@ def train_and_evaluate(args): label_shape = [args.batch_size] if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None @@ -308,8 +308,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 124e0248b..2934e48df 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,11 +40,13 @@ QuantizerFactory, QuantizeLayout, noop_quantizer_set, + should_use_rht, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense +from transformer_engine.common import recipe GEMM_CASES = [ (256, 256, 512), @@ -56,16 +58,23 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2] LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] -is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available() -is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) -supported_scaling_modes = [] +# TODO(Phuong): remove unneccessary pytest skips +is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.DELAYED_TENSOR_SCALING +) +is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.MXFP8_1D_SCALING +) +is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.NVFP4_1D_SCALING +) + """ Find supported scaling modes""" -if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) - supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) -if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) +supported_scaling_modes = helper.get_supported_scaling_modes() +non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling] +supported_recipes = helper.get_supported_quantization_recipes() +supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] def is_shape_supported_by_mxfp8(input_shape): @@ -83,12 +92,13 @@ def assert_bitwise_scaled_tensors( a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True ): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): - if not precise_comparison: + if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling: assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype) return assert a.scaling_mode == b.scaling_mode assert a.scale_inv.dtype == b.scale_inv.dtype + assert a.data_layout == b.data_layout if a.scaling_mode.is_tensor_scaling(): # Assert in dq_dtype as some unfused codepaths have an intermediate cast # to an input dtype which reduces precision compared to everything in fp32 @@ -96,6 +106,16 @@ def assert_bitwise_scaled_tensors( elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Compare MXFP8 scales as uint8 assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8)) + elif a.scaling_mode.is_nvfp4_scaling: + assert_allclose(a.amax, b.amax) + assert_allclose(a.scale_inv, b.scale_inv) + if not precise_comparison: + mismatch = a.data != b.data + mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32)) + assert ( + mismatch_fraction < 0.05 + ), f"Mismatch fraction {mismatch_fraction} is too high" + return else: raise ValueError(f"Unsupported scaling mode {a.scaling_mode}") assert_allclose(a.data, b.data) @@ -603,10 +623,24 @@ def test_norm_forward_with_block_scaling_fp8( ) -QUANTIZE_OUTPUT_DTYPES = { +QUANTIZE_OUTPUT_FP8_DTYPES = { "L0": [jnp.float8_e4m3fn], "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2], } +QUANTIZE_OUTPUT_DTYPES = { + test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn] + for test_level in QUANTIZE_OUTPUT_FP8_DTYPES +} +QUANTIZE_QDTYPE_AND_SCALING_MODES = { + test_level: [ + (q_dtype, scaling_mode) + for q_dtype, scaling_mode in zip( + QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes + ) + if q_dtype in scaling_mode.get_compatible_q_dtypes() + ] + for test_level in QUANTIZE_OUTPUT_FP8_DTYPES +} ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [ ((32, 64), -1), @@ -615,8 +649,7 @@ def test_norm_forward_with_block_scaling_fp8( ((32, 256, 128), -1), ((32, 256, 128), -2), ((64, 32, 32, 256), -1), - ((64, 32, 32, 256), -2), - ((64, 32, 32, 256), -3), + ((8192, 2, 4096), -2), ] QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = { @@ -636,18 +669,38 @@ def test_norm_forward_with_block_scaling_fp8( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn]) @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper( - "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + "q_layout", + [ + QuantizeLayout.ROWWISE, + QuantizeLayout.COLWISE, + QuantizeLayout.ROWWISE_COLWISE, + ], ) class TestQuantize: """ Purely quantization related tests that will always test on a wider set of types and shapes """ + def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + """Temporary hack to skip unsupported FP4 cases until we implement them""" + if q_dtype not in scaling_mode.get_compatible_q_dtypes(): + pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}") + return + + # HACK: FIXME TODO(jberchtold) + row = reduce(operator.mul, input_shape[flatten_axis:], 1) + col = reduce(operator.mul, input_shape[:flatten_axis], 1) + will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout) + if will_use_rht and (row % 64 != 0 or col % 128 != 0): + pytest.skip("Unfused RHT is not supported currently, skipping") + def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + key = jax.random.PRNGKey(0) # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling) @@ -657,6 +710,68 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt q_layout=q_layout, ) + if scaling_mode.is_nvfp4_scaling: + if in_dtype != jnp.bfloat16: + pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently") + return + q_func = _jax_quantize + # For NVFP4 scaling, the maximum possible error for a single value can be high between the dequantized and original tensors. To ensure quantization and dequantization is operating correctly without requiring a very high tolerance for all values, we instead test that quantizing the dequantized tensor is bitwise identical to the original quantized tensor. + x = jax.random.uniform(key, input_shape, in_dtype) * 10 + q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis) + + dq_rowwise = None + dq_colwise = None + if isinstance(q1, ScaledTensor1x): + dq = q1.dequantize() + if q1.is_colwise: + dq_colwise = dq + else: + dq_rowwise = dq + elif isinstance(q1, ScaledTensor2x): + dq_rowwise = q1.rowwise_tensor.dequantize() + dq_colwise = q1.colwise_tensor.dequantize() + else: + raise ValueError(f"Unsupported output type {type(q1)}") + + # We only compare Q-DQ for the same quantization layout. If we for example QDQ rowwise, then re-quantize colwise, the error will be larger and may not be bitwise identical to the original colwise quantization. + if dq_rowwise is not None: + assert ( + dq_rowwise.shape == x.shape + ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}" + q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis) + q2_rowwise = ( + q2_rowwise + if isinstance(q2_rowwise, ScaledTensor1x) + else q2_rowwise.rowwise_tensor + ) + q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor + assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise) + + if dq_colwise is not None: + # Since this is for NVFP4, we are assuming colwise has T layout and we do a transpose here to get back to original shape + flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis + colwise_flatten_axis = len(input_shape) - flatten_axis + dq_colwise = jnp.transpose( + dq_colwise, + (*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)), + ) + assert ( + dq_colwise.shape == x.shape + ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}" + q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis) + q2_colwise = ( + q2_colwise + if isinstance(q2_colwise, ScaledTensor1x) + else q2_colwise.colwise_tensor + ) + q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor + assert_bitwise_scaled_tensors(q1_colwise, q2_colwise) + + assert ( + dq_rowwise is not None or dq_colwise is not None + ), "At least one of rowwise or colwise dq must be not None" + return + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -664,9 +779,33 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis) assert_dequantized_scaled_tensor(scaled_tensor, x) + def _should_use_precise_comparison( + self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ): + # TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values. + RHT_SLIGHT_MISMATCH_SHAPES = [ + ((32, 256, 128), -1), + ((64, 32, 32, 256), -1), + ((8192, 2, 4096), -2), + ] + + if ( + should_use_rht(scaling_mode, q_layout=q_layout) + and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES + ): + # TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes + return False + + if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16: + # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation + return False + + return True + def test_quantize_bitwise( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): + self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) @@ -677,15 +816,202 @@ def test_quantize_bitwise( jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) - assert_bitwise_scaled_tensors(te_output, jax_output) + try: + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + except AssertionError as e: + if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16: + error_message = e.args[0] + if "RHT requires input to be bfloat16" in error_message: + # Successfully caught the expected error, early return from the test + return + raise e + + assert_bitwise_scaled_tensors( + te_output, + jax_output, + precise_comparison=self._should_use_precise_comparison( + in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ), + ) + + def test_quantize_bitwise_jitted( + self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis + ): + self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + + te_quantizer, jax_quantizer = QuantizerFactory.create( + n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout + ) + + jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3)) + te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,)) + + jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) + + try: + te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + except AssertionError as e: + if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16: + error_message = e.args[0] + if "RHT requires input to be bfloat16" in error_message: + # Successfully caught the expected error, early return from the test + return + raise e + + assert_bitwise_scaled_tensors( + te_output, + jax_output, + precise_comparison=self._should_use_precise_comparison( + in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ), + ) + + +@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn]) +@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) +@pytest_parametrize_wrapper( + "scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling] +) +@pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] +) +class TestStochasticRounding: + + def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]: + """Dequantizes a ScaledTensor back to it's original jnp.ndarray form. This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor.""" + if isinstance(scaled_tensor, ScaledTensor1x): + dq = scaled_tensor.dequantize() + if scaled_tensor.data_layout == "T": + dq = jnp.transpose( + dq, + ( + *range(scaled_tensor.flatten_axis, dq.ndim), + *range(scaled_tensor.flatten_axis), + ), + ) + return [dq] + elif isinstance(scaled_tensor, ScaledTensor2x): + [rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor) + [colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor) + return [rowwise_dq, colwise_dq] + raise ValueError( + "Unsupported ScaledTensor type, expected ScaledTensor but received" + f" {type(scaled_tensor)}" + ) + + def _sample_sr_qdq( + self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> list[jnp.ndarray]: + """Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors.""" + dq_tensors = [] + + key = jax.random.PRNGKey(0) + + for i in range(num_samples): + iter_key = jax.random.fold_in(key, i) + sr_rng_state = jax.random.randint( + iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32 + ) + quantizer = QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + stochastic_rounding_rng_state=sr_rng_state, + ) + + q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) + iter_dq = self._dequantize(q_output) + dq_tensors.extend(iter_dq) + + avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0) + assert avg_sr_tensor.shape == inputs.shape, ( + f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" + f" {inputs.shape}" + ) + + sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) + + dq_var = jnp.var(jnp.stack(dq_tensors)) + assert ( + dq_var > 0 + ), "Variance of dequantized tensors is zero, stochastic rounding may not be working" + + return dq_tensors + + def _round_nearest( + self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> jnp.ndarray: + """Quantizes and dequantizes the input tensor with round nearest quantization.""" + quantizer = QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + stochastic_rounding_rng_state=None, + ) + q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) + rn_dq = self._dequantize(q_output)[0] + return rn_dq + + def _test_sr( + self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> float: + """Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples.""" + dq_tensors = self._sample_sr_qdq( + num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0) + assert avg_sr_tensor.shape == inputs.shape, ( + f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" + f" {inputs.shape}" + ) + + round_nearest_tensor = self._round_nearest( + q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + + sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) + rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs)) + + assert sr_mae < rn_mae, ( + f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than" + f" round nearest ({rn_mae})" + ) + + return sr_mae + + def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other.""" + # HACK: FIXME TODO(jberchtold) + row = reduce(operator.mul, input_shape[flatten_axis:], 1) + col = reduce(operator.mul, input_shape[:flatten_axis], 1) + will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout) + if will_use_rht and (row % 64 != 0 or col % 128 != 0): + pytest.skip("Unfused RHT is not supported currently, skipping") + + key = jax.random.PRNGKey(0) + inputs = jax.random.uniform(key, input_shape, in_dtype) + + NUM_SAMPLES = 10 + + te_mean_error = self._test_sr( + NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + jax_mean_error = self._test_sr( + NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + + assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) -@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) +@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @pytest_parametrize_wrapper("with_group_sizes", [True, False]) @pytest_parametrize_wrapper( @@ -724,7 +1050,6 @@ def test_grouped_qdq( q_layout=q_layout, n_groups=n_groups, ) - scaled_tensor = tex.grouped_quantize( x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer ) @@ -736,9 +1061,8 @@ def test_grouped_qdq( class TestFusedQuantize: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) @@ -860,7 +1184,7 @@ def test_quantize_dact_dbias_no_quantization( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] @@ -886,7 +1210,7 @@ def test_quantize_dact_dbias_tensor_scaling( @pytest_parametrize_wrapper( "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)] ) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] @@ -919,6 +1243,11 @@ def test_quantize_dact_dbias_mxfp8_scaling( (jnp.float8_e4m3fn, jnp.float8_e5m2), ] +supported_nvfp4_scaling_mode_pairs = [ + (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING), + (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING), +] + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): @@ -960,7 +1289,7 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): @@ -994,6 +1323,40 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) + # TODO(Phuong): add bitwise test + @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + @pytest_parametrize_wrapper("with_jax_gemm", [True, False]) + def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm): + x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T" + w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N" + if x_uses_rht != w_uses_rht: + # TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT + pytest.skip("RHT must be used for both or neither operand, skipping") + + lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + lhs_quantizer = QuantizerFactory.create( + scaling_mode=lhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + ) + rhs_quantizer = QuantizerFactory.create( + scaling_mode=rhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + ) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): data_layout = "NN" @@ -1019,11 +1382,10 @@ def ref_func(x, w, data_layout): assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)]) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -1044,14 +1406,9 @@ def ref_func(x, w, bias, data_layout): value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True, - ) + quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( @@ -1062,10 +1419,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp.float8_e5m2) + assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype) @pytest.fixture(name="random_inputs") @@ -1087,11 +1444,11 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) + @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): + def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ @@ -1108,12 +1465,7 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) - quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True, - ) + quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) if norm_type == "layernorm": beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) @@ -1148,7 +1500,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): prim_out, ( @@ -1158,22 +1510,22 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) - assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype) + assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) + @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm + self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule @@ -1201,10 +1553,7 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, - scaling_mode=scaling_mode, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2 if scaling_mode.is_tensor_scaling() else jnp.float8_e4m3fn, - is_2x2x=True, + fp8_recipe=recipe, ) if norm_type == "layernorm": @@ -1251,7 +1600,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): prim_out, ( @@ -1272,18 +1621,16 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) - - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp.float8_e5m2) - if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp.float8_e5m2) - - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp.float8_e5m2) + fwd_dtype = quantizer_sets[0].x.q_dtype + bwd_dtype = quantizer_sets[0].dgrad.q_dtype + assert_allclose(prim_out, ref_out, dtype=fwd_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype) + assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype) + assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp.float8_e5m2) - - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp.float8_e5m2) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp.float8_e5m2) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype) # E5M2 * E5M2 is not supported @@ -1388,7 +1735,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): fwd_dtype, bwd_dtype = fwd_bwd_dtype @@ -1469,7 +1816,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): "fwd_bwd_dtype", [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)], ) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d38f43d00..bf78ed3bb 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,6 +1,7 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +import re from typing import Callable, Sequence, Union, Optional import pytest @@ -17,7 +18,11 @@ ) from transformer_engine.common import recipe -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import ( + is_fp8_available, + ScalingMode, + get_quantize_config_with_recipe, +) from transformer_engine.jax import fp8_autocast from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.layernorm_mlp import layernorm_mlp @@ -33,19 +38,20 @@ W_JOINED_AXES, ) from transformer_engine.jax.sharding import MeshResource -from transformer_engine.jax.quantize import QuantizerFactory +from transformer_engine.jax.quantize import ( + QuantizerFactory, + get_supported_quantization_recipes, + is_scaling_mode_supported, +) from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability -is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) +is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) -SUPPORTED_RECIPES = [] -if is_fp8_supported: - SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) - SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) -if is_mxfp8_supported: - SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) +SUPPORTED_RECIPES = get_supported_quantization_recipes() +SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in SUPPORTED_RECIPES] DTYPES = [jnp.bfloat16, jnp.float16] INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in] @@ -141,6 +147,7 @@ def layernorm_fp8_mlp_prim_func( layernorm_type: str = "rmsnorm", activation_type: Sequence[Union[str, Callable]] = ("gelu",), multi_gpus: bool = False, + quantization_recipe: recipe.Recipe = None, ) -> jnp.ndarray: if multi_gpus: @@ -154,7 +161,9 @@ def layernorm_fp8_mlp_prim_func( dot_1_input_axes = dot_2_input_axes = None kernel_1_axes = kernel_2_axes = None - quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) + quantizer_sets = QuantizerFactory.create_set( + n_quantizer_sets=2, fp8_recipe=quantization_recipe + ) # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 return jnp.mean( @@ -182,7 +191,7 @@ def _test_layernorm_mlp_grad( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, use_shardy, with_jax_gemm, ): @@ -202,7 +211,9 @@ def _test_layernorm_mlp_grad( # Single GPU with fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + mesh_resource=MeshResource(), ): single_jitter = jax.jit( value_and_grad_func, @@ -214,7 +225,9 @@ def _test_layernorm_mlp_grad( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + mesh_resource=mesh_resource, ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) @@ -254,10 +267,16 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn - bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 + fwd_test_type = bwd_test_type = dtype + if quantization_recipe is not None: + quantize_config = get_quantize_config_with_recipe(quantization_recipe) + fwd_test_type = quantize_config.FWD_DTYPE + bwd_test_type = quantize_config.BWD_DTYPE - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + if fwd_test_type == jnp.float16 and use_bias: + assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5) + else: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) for i in range(len(inputs)): if multi_grads[i] is not None: @@ -278,13 +297,12 @@ def _test_layernorm_mlp_grad( err_msg=f"multi_grads[{i}] is not close", ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -293,27 +311,28 @@ def test_layernorm_mlp_grad( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, with_jax_gemm, ): + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, use_shardy=False, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, @@ -322,18 +341,18 @@ def test_layernorm_mlp_grad_shardy( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, with_jax_gemm, ): - if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=True, with_jax_gemm=with_jax_gemm, ) @@ -346,7 +365,7 @@ def _test_layernorm_mlp( input_shape, dtype, use_fp8, - fp8_recipe, + quantization_recipe, use_shardy, with_jax_gemm, ): @@ -355,14 +374,16 @@ def _test_layernorm_mlp( layernorm_type = "rmsnorm" rng = jax.random.PRNGKey(0) - subkeys = jax.random.split(rng, 2) + subkeys = jax.random.split(rng, 3) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) - init_rngs = {"params": subkeys[1]} + init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]} with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with fp8_autocast( + enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=MeshResource() + ): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, intermediate_dim=INTERMEDIATE, @@ -371,7 +392,7 @@ def _test_layernorm_mlp( ) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( - params_single, x, deterministic=True + params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]} ) # Multi GPUs @@ -379,7 +400,7 @@ def _test_layernorm_mlp( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( layernorm_type=layernorm_type, @@ -399,7 +420,7 @@ def _test_layernorm_mlp( ) params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( - params_sharded, x, deterministic=True + params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]} ) # Make sure params values are the same @@ -411,8 +432,8 @@ def _test_layernorm_mlp( rtol = None l40_tolerance_update = ( get_min_device_compute_capability() == 89 - and fp8_recipe == recipe.DelayedScaling() and use_fp8 + and quantization_recipe.delayed() and dtype == jnp.float16 and activation_type == ("gelu",) ) @@ -430,8 +451,8 @@ def _test_layernorm_mlp( # within tolerance to the float32 ground truth. jax_triton_gemm_precision_tolerance_update = ( with_jax_gemm - and fp8_recipe is not None - and (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()) + and quantization_recipe is not None + and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling()) and dtype in (jnp.bfloat16, jnp.float16) and activation_type == ("gelu", "linear"), ) @@ -457,22 +478,30 @@ def test_layernorm_mlp_layer( input_shape, dtype, use_fp8=False, - fp8_recipe=None, + quantization_recipe=None, use_shardy=False, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + quantization_recipe, + with_jax_gemm, ): + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp( mesh_config, activation_type, @@ -480,7 +509,7 @@ def test_layernorm_mlp_layer_fp8( input_shape, dtype, use_fp8=True, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=False, with_jax_gemm=with_jax_gemm, ) @@ -501,24 +530,30 @@ def test_layernorm_mlp_layer_shardy( input_shape, dtype, use_fp8=False, - fp8_recipe=None, + quantization_recipe=None, use_shardy=True, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + quantization_recipe, + with_jax_gemm, ): - if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp( mesh_config, activation_type, @@ -526,7 +561,7 @@ def test_layernorm_mlp_layer_fp8_shardy( input_shape, dtype, use_fp8=True, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=True, with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index e4511e1fe..e9f71a32f 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -10,20 +10,27 @@ import numpy as np from utils import assert_allclose -from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling +from transformer_engine.common.recipe import ( + DelayedScaling, + MXFP8BlockScaling, + Float8CurrentScaling, + NVFP4BlockScaling, +) from transformer_engine.common.recipe import Format as FP8Format -from transformer_engine.jax import fp8_autocast, get_delayed_scaling +from transformer_engine.jax import fp8_autocast from transformer_engine.jax.quantize import ( get_quantize_config, - is_fp8_available, + is_scaling_mode_supported, ScalingMode, update_collections, TensorSource, ) +from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.sharding import MeshResource, global_mesh_resource -is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) +is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) class TestHelper(unittest.TestCase): @@ -52,14 +59,16 @@ class TestFP8Functions(unittest.TestCase): def _check_default_state(self): self.assertFalse(get_quantize_config().is_fp8_enabled()) - def _compare_delay_scaling(self, ref, test): - self.assertTrue(ref.margin == test.margin) - self.assertTrue(ref.fp8_format == test.fp8_format) - self.assertTrue(ref.amax_history_len == test.amax_history_len) - self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) + def _compare_delay_scaling(self, test): + self.assertEqual(get_quantize_config().MARGIN, test.margin) + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) + self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len) + self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) for tensor_source in TensorSource: self.assertEqual( get_quantize_config().get_scaling_mode(tensor_source), @@ -67,13 +76,26 @@ def _compare_current_scaling(self, test): ) def _compare_mxfp8_scaling(self, test): - self.assertEqual(get_quantize_config().MARGIN, test.margin) - self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) for tensor_source in TensorSource: self.assertEqual( get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING ) + def _compare_nvfp4_scaling(self, test): + self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0]) + self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1]) + for tensor_source in TensorSource: + target_scaling_mode = ( + ScalingMode.NVFP4_2D_SCALING + if tensor_source == TensorSource.KERNEL + else ScalingMode.NVFP4_1D_SCALING + ) + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode + ) + @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_delayed_scaling(self): self._check_default_state() @@ -86,14 +108,14 @@ def test_fp8_autocast_delayed_scaling(self): ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) + self._compare_delay_scaling(ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) + self._compare_delay_scaling(ds) self._check_default_state() @@ -133,16 +155,27 @@ def test_fp8_autocast_mxfp8_block_scaling(self): self._check_default_state() - bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) + bs = MXFP8BlockScaling() with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() - bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) + @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) + def test_fp8_autocast_nvfp4_block_scaling(self): + self._check_default_state() + + with fp8_autocast( + enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource() + ): + self._check_default_state() + + self._check_default_state() + + bs = NVFP4BlockScaling() with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_mxfp8_scaling(bs) + self._compare_nvfp4_scaling(bs) self._check_default_state() diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 8ad6dccfe..c28e68a15 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1544,6 +1544,12 @@ def dtype_tols( rtol = eps_relaxed if atol is None: atol = max(ulp, eps_relaxed) + + # Manually set tols for nvfp4 + if dtype == jnp.float4_e2m1fn: + atol = 0.05 + rtol = 0.1 + return {"rtol": rtol, "atol": atol} diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 0b5e43402..354a1293e 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,7 +34,7 @@ from . import flax from . import quantize -from .quantize import fp8_autocast, update_collections, get_delayed_scaling +from .quantize import fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME from .sharding import MeshResource @@ -47,7 +47,6 @@ "NVTE_FP8_COLLECTION_NAME", "fp8_autocast", "update_collections", - "get_delayed_scaling", "MeshResource", "flax", "quantize", diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index ef8d76cd0..c0285e157 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for c++ extensions""" from .activation import * +from .amax import * from .attention import * from .normalization import * from .quantization import * diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index be1f9f956..bb3c56bcf 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1314,7 +1314,10 @@ def act_lu( ) return out - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( x=x, @@ -1488,7 +1491,10 @@ def quantize_dact_dbias( if war_output is not None: return war_output - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = dact_lu( dz=dz, diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py new file mode 100644 index 000000000..2f3bc402e --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -0,0 +1,420 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""JAX/TE custom ops for amax calculation""" +from enum import Enum + + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.experimental.custom_partitioning import SdyShardingRule +from jax.sharding import PartitionSpec + +from .base import BasePrimitive, register_primitive +from .misc import ( + get_padded_spec, + NamedSharding, +) +from ..sharding import ( + global_mesh_resource, + lax_paral_op, +) +from ..quantize import ( + get_wgrad_sign_vector, + get_sign_from_vector, +) + + +__all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"] + + +class AmaxScope(Enum): + """ + Amax Scope Enum + """ + + LOCAL = 1 + TPSP = 2 + FSDP = 3 + + def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): + """Reduce the amax based on its scope""" + gmesh = global_mesh_resource() + sequence_dim = 0 if transpose_batch_sequence else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: + return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + # Run AR across FSDP + if self is AmaxScope.FSDP: + return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + return amax + + +class AmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning + """ + + name = "jax_local_amax" + multiple_results = False + impl_static_args = ( + 1, + 2, + ) # amax_scope, transpose_batch_sequence + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + transpose_batch_sequence, + ): + """ + amax calcuation abstract + """ + del amax_scope, transpose_batch_sequence + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + return out_aval + + @staticmethod + def impl( + x, + amax_scope, + transpose_batch_sequence, + ): + """ + amax calcuation implementation + """ + del amax_scope, transpose_batch_sequence + amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) + return amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + transpose_batch_sequence, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused. + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding + + @staticmethod + def partition( + amax_scope, + transpose_batch_sequence, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="AmaxCalculation.amax_sharding", + ) + + def sharded_impl(x): + amax = AmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + amax, x_spec, transpose_batch_sequence, mesh + ) + + return amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, amax_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types): + """ + amax calcuation shardy_sharding_rule + """ + del amax_scope, transpose_batch_sequence, mesh, result_types + prefix = "AmaxCal" + input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_spec = (f"{prefix}_amax",) + return SdyShardingRule((input_spec,), (output_spec,)) + + +register_primitive(AmaxCalculationPrimitive, outer_only=True) + + +class RHTAmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning for calculating regular and post-Random Hadamard Transform (RHT) amax using TE's fused kernels. + """ + + name = "te_rht_amax_ffi" + multiple_results = True + impl_static_args = ( + 1, # amax_scope + 2, # transpose_batch_sequence + 3, # rht_matrix_random_sign_mask_t + 4, # produce_regular_amax + 5, # flatten_axis + ) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """ + amax calcuation abstract + """ + del ( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ) + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.bfloat16], f"RHT requires input to be bfloat16, but got {dtype}" + + amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + post_rht_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + + return amax_aval, post_rht_amax_aval + + @staticmethod + def lowering( + ctx, + x, + *, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """ + te_dbias_quantize_p lowering rules + """ + del amax_scope, transpose_batch_sequence + (x_aval,) = ctx.avals_in + assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + flatten_axis = flatten_axis if flatten_axis >= 0 else flatten_axis + len(x_aval.shape) + assert 0 < flatten_axis < len(x_aval.shape), "Flatten axis out of bounds!" + + return ffi.ffi_lowering( + RHTAmaxCalculationPrimitive.name, + )( + ctx, + x, + rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t, + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + + @staticmethod + def impl( + x, + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + ): + """ + amax calcuation implementation + """ + assert RHTAmaxCalculationPrimitive.inner_primitive is not None + ( + amax, + post_rht_amax, + ) = RHTAmaxCalculationPrimitive.inner_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t, + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + return amax, post_rht_amax + + @staticmethod + def infer_sharding_from_operands( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation infer_sharding_from_operands + """ + del ( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + arg_infos, + result_infos, + ) # Unused. + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="RHTAmaxCalculationPrimitive.out_sharding", + ) + return amax_sharding, amax_sharding + + @staticmethod + def partition( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + arg_infos, + result_infos, + ): + """ + amax calcuation partition + """ + del result_infos + x_spec = get_padded_spec(arg_infos[0]) + amax_sharding = NamedSharding( + mesh, + PartitionSpec(None), + desc="RHTAmaxCalculationPrimitive.amax_sharding", + ) + out_shardings = (amax_sharding, amax_sharding) + + def sharded_impl(x): + amax, post_rht_amax = RHTAmaxCalculationPrimitive.impl( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t, + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + amax, x_spec, transpose_batch_sequence, mesh + ) + post_rht_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + post_rht_amax, x_spec, transpose_batch_sequence, mesh + ) + + return amax, post_rht_amax + + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + value_types, + result_types, + ): + """ + amax calcuation shardy_sharding_rule + """ + del ( + amax_scope, + transpose_batch_sequence, + rht_matrix_random_sign_mask_t, + produce_regular_amax, + flatten_axis, + mesh, + result_types, + ) + prefix = "RHTAmaxCal" + input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) + output_amax_spec = (f"{prefix}_amax",) + output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",) + return SdyShardingRule((input_spec,), (output_amax_spec, output_post_rht_amax_spec)) + + +register_primitive(RHTAmaxCalculationPrimitive) + + +def calculate_amax(x: jnp.ndarray, amax_scope: AmaxScope, transpose_batch_sequence: bool): + """ + Compute the maximum absolute value (amax) of the input tensor. + """ + assert AmaxCalculationPrimitive.outer_primitive is not None + return AmaxCalculationPrimitive.outer_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + + +def calculate_post_rht_amax( + x: jnp.ndarray, + amax_scope: AmaxScope, + transpose_batch_sequence: bool, + produce_regular_amax: bool, + flatten_axis: int, +): + """Compute the post-Random Hadamard Transform (RHT) amax of the input tensor, and optionally the regular amax. + + Args: + x: Input tensor. + amax_scope: The scope for amax reduction (local, TPSP, or FSDP). + transpose_batch_sequence: Whether the input tensor has its batch and sequence dimensions transposed. + produce_regular_amax: Whether to compute and return the regular amax alongside the post-RHT amax. + flatten_axis: The axis at which to flatten the input tensor before applying RHT. + Returns: + A tuple containing: + - The regular amax if `produce_regular_amax` is True, otherwise None. + - The post-RHT amax. + """ + amax, post_rht_amax = RHTAmaxCalculationPrimitive.outer_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=get_sign_from_vector(get_wgrad_sign_vector()), + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + + if produce_regular_amax: + return amax, post_rht_amax + return None, post_rht_amax diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7fe433bcc..b72161f1a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -32,6 +32,7 @@ AbstractBaseTensor, NoScaleTensor, ScaledTensor, + ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, @@ -43,6 +44,7 @@ noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, + should_use_rht, ) from .misc import get_padded_spec, is_all_reduce_in_float32 from ..sharding import ( @@ -138,6 +140,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ need_lhs_colwise = lhs_is_transposed and ( lhs_quantizer.scaling_mode.is_1d_block_scaling() or not is_fp8_gemm_with_all_layouts_supported() + or lhs_quantizer.scaling_mode.is_nvfp4_scaling ) flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( @@ -153,6 +156,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ need_rhs_colwise = not rhs_is_transposed and ( rhs_quantizer.scaling_mode.is_1d_block_scaling() or not is_fp8_gemm_with_all_layouts_supported() + or rhs_quantizer.scaling_mode.is_nvfp4_scaling ) flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( @@ -165,9 +169,27 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x) + def uses_rht(q: AbstractBaseTensor) -> bool: + return isinstance(q, ScaledTensor1x) and should_use_rht( + q.scaling_mode, is_colwise=q.is_colwise + ) + + # TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class + assert uses_rht(lhs_q) == uses_rht(rhs_q), ( + "With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise" + " quantized as well. This is to ensure the RHT is applied to both and will cancel out in" + " the GEMM." + ) + return lhs_q, rhs_q +def _get_nvfp4_tensor_scale_inv(amax): + DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32) + return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) + + def collective_gemm_bootstrap( num_total_devices, num_devices_per_process, @@ -345,7 +367,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -357,6 +379,8 @@ def abstract( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -404,7 +428,9 @@ def _dims_are_consecutive(dims): lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) if scaling_mode != ScalingMode.NO_SCALING: - assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes( + lhs.dtype, rhs.dtype + ), ( "cuBLAS GEMM quantized operands have incompatible data types: " f"{lhs.dtype} x {rhs.dtype}." ) @@ -484,6 +510,8 @@ def _dims_are_consecutive(dims): f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." ) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + assert alpha.size == 1 and alpha.dtype == jnp.float32 + assert beta.size == 1 and beta.dtype == jnp.float32 # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() @@ -510,6 +538,8 @@ def lowering( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -530,7 +560,7 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { "scaling_mode": int(scaling_mode.value), "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), @@ -563,6 +593,8 @@ def impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -626,6 +658,8 @@ def impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, @@ -675,6 +709,8 @@ def outer_impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -694,6 +730,8 @@ def outer_impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -1001,6 +1039,9 @@ def partition( gelu_input_specs = (None,) arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + # Alpha, beta + arg_shardings += (none_sharding, none_sharding) + # Assemble output shardings out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] @@ -1014,7 +1055,7 @@ def partition( pre_gelu_specs = (None,) out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) - def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta): # We should not fuse bias in the output reduction case sharded_fuse_bias = fuse_bias and reduce_spec is None outputs = GemmPrimitive.impl( @@ -1024,6 +1065,8 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, @@ -1114,8 +1157,10 @@ def _generate_operand_rules(name, ndim, cdims): rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) bias_spec = rhs_non_cspec if fuse_bias else ("…4",) - dbias_spec = bias_spec if grad else ("…5") - gelu_spec = out_spec if fuse_gelu else ("…6",) + gelu_spec = out_spec if fuse_gelu else ("…5",) + alpha_spec = ("_6",) + beta_spec = ("_7",) + dbias_spec = bias_spec if grad else ("…8") return SdyShardingRule( operand_mappings=( @@ -1125,6 +1170,8 @@ def _generate_operand_rules(name, ndim, cdims): rhs_scale_specs, bias_spec, gelu_spec, + alpha_spec, + beta_spec, ), result_mappings=( out_spec, @@ -1178,6 +1225,7 @@ def _te_gemm( # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + lhs_amax = rhs_amax = None # Extract GEMM custom op inputs from quantized operands if isinstance(lhs_q, ScaledTensor): assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( @@ -1192,6 +1240,7 @@ def _te_gemm( lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_amax = lhs_q.amax if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -1201,7 +1250,11 @@ def _te_gemm( if isinstance(rhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() - assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + assert ( + rhs_q.scaling_mode == lhs_q.scaling_mode + or rhs_q.scaling_mode.is_nvfp4_scaling + and lhs_q.scaling_mode.is_nvfp4_scaling + ), ( "cuBLAS GEMM quantized operands have mismatched scaling types, " f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." ) @@ -1209,6 +1262,15 @@ def _te_gemm( rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_amax = rhs_q.amax + + alpha = jnp.ones((1,), jnp.float32) + beta = jnp.zeros((1,), jnp.float32) + if scaling_mode.is_nvfp4_scaling: + assert lhs_amax is not None and rhs_amax is not None + lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) + rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) + alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -1224,6 +1286,8 @@ def _te_gemm( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=(lhs_cdims, rhs_cdims), scaling_mode=scaling_mode, @@ -1514,15 +1578,17 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): @partial(jax.jit, static_argnums=(2,)) -def _jax_gemm_mxfp8_1d( +def _jax_scaled_matmul( lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] ): """ JAX GEMM for MXFP8 via scaled_matmul """ - assert ( - rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING - ), "rhs does not have MXFP8 1D scaling mode" + assert rhs.scaling_mode in ( + ScalingMode.MXFP8_1D_SCALING, + ScalingMode.NVFP4_1D_SCALING, + ScalingMode.NVFP4_2D_SCALING, + ), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -1537,21 +1603,48 @@ def _jax_gemm_mxfp8_1d( f" {rhs.is_colwise}" ) + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: + out_dtype = lhs.dq_dtype + assert ( + lhs.data_layout == "N" and rhs.data_layout == "N" + ), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + else: + if lhs.data_layout == "T": + lhs_contract = transpose_dims( + lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis + ) + if rhs.data_layout == "T": + rhs_contract = transpose_dims( + rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis + ) + out_dtype = jnp.float32 + # Reshape + Transpose (if needed) # [..., M, K] -> [1, reduce(..., M), K] # [..., K, M] -> [1, reduce(..., M), K] - lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch)) - rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch)) - lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) - rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) + lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch), lhs.data_layout == "T") + rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch), rhs.data_layout == "T") + lhs_scale_3d = _shape_normalization( + lhs.scale_inv, (lhs_contract, lhs_batch), lhs.data_layout == "T" + ) + rhs_scale_3d = _shape_normalization( + rhs.scale_inv, (rhs_contract, rhs_batch), rhs.data_layout == "T" + ) # JAX scaled_matmul only supports NT now (TN-gemm) # * Expected shape: # * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block) out_3d = jax.nn.scaled_matmul( - lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype + lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) + if lhs.scaling_mode.is_nvfp4_scaling: + assert lhs.amax is not None and rhs.amax is not None + lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) + rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) + alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv + out_3d = (out_3d * alpha).astype(lhs.dq_dtype) + # Reshape [1, reduce(..., M), N] -> [..., M, N] lhs_remain_shape = tuple( lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract @@ -1560,6 +1653,7 @@ def _jax_gemm_mxfp8_1d( rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract ) out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape) + return out @@ -1575,7 +1669,7 @@ def _jax_gemm( """ dim_nums = (contracting_dims, ((), ())) - def _jax_gemm_fp8_impl(lhs, rhs): + def _jax_gemm_impl(lhs, rhs): if lhs.scaling_mode.is_tensor_scaling(): assert ( rhs.scaling_mode == lhs.scaling_mode @@ -1587,15 +1681,15 @@ def _jax_gemm_fp8_impl(lhs, rhs): ) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) - if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) + if lhs.scaling_mode.is_1d_block_scaling: + return _jax_scaled_matmul(lhs, rhs, dim_nums) raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + return _jax_gemm_impl(lhs_q, rhs_q) if ( isinstance(lhs, jnp.ndarray) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 52f5edbf3..572d82f18 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -6,8 +6,6 @@ import os import functools from typing import Tuple -from importlib.metadata import version as get_pkg_version -from packaging.version import Version as PkgVersion import numpy as np @@ -75,7 +73,8 @@ def jax_dtype_to_te_dtype(jax_dtype): jnp.int64.dtype: TEDType.kInt64, jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, - jnp.uint8.dtype: TEDType.kByte, + jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0, + jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1, } if jax_dtype not in converter: @@ -151,16 +150,6 @@ def get_cudnn_version() -> Tuple[int, int, int]: return (major, minor, patch) -@functools.lru_cache(maxsize=None) -def jax_version_meet_requirement(version: str): - """ - Helper function checking if required JAX version is available - """ - jax_version = PkgVersion(get_pkg_version("jax")) - jax_version_required = PkgVersion(version) - return jax_version >= jax_version_required - - def get_xla_flag(flag: str, default=None, cast=str): """ Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3ce8a19a7..90ab5fb7f 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -28,7 +28,10 @@ get_cudnn_version, ) from .quantization import _quantize_dbias_impl, AmaxScope -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp_tpsp, +) from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, @@ -1031,7 +1034,10 @@ def layernorm_fwd( ) return out, mu, rsigma - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. out, mu, rsigma = layernorm_fwd( x=x, @@ -1276,7 +1282,10 @@ def rmsnorm_fwd( ) return out, rsigma - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. out, rsigma = rmsnorm_fwd( x=x, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 38fd50a00..b3f1e60f9 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,7 +6,6 @@ from functools import reduce from typing import Tuple, Optional, Union import math -from enum import Enum import jax @@ -17,6 +16,7 @@ import transformer_engine_jax +from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, @@ -31,8 +31,7 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, - global_mesh_resource, - lax_paral_op, + num_of_devices, ) from ..quantize import ( ScaledTensor2x, @@ -45,6 +44,8 @@ ScalingMode, compute_scale_from_amax, NoScaleTensor, + get_rht_matrix, + should_use_rht, ) @@ -59,14 +60,16 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): name = "te_dbias_quantize_ffi" multiple_results = True impl_static_args = ( - 3, - 4, - 5, - 6, - 7, - 8, - 9, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer + 6, # out_dtype + 7, # scaling_mode + 8, # q_layout + 9, # flatten_axis + 10, # scale_dtype + 11, # is_dbias + 12, # is_outer + 13, # stochastic_rounding + 14, # use_rht + ) inner_primitive = None outer_primitive = None @@ -75,6 +78,9 @@ def abstract( x_aval, scale_aval, amax_aval, + sr_rng_state_aval, + post_rht_amax_aval, + rht_matrix_aval, *, out_dtype, scaling_mode, @@ -83,6 +89,8 @@ def abstract( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p abstract @@ -91,6 +99,28 @@ def abstract( assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] out_shape = x_aval.shape assert scale_aval is None or scale_aval.dtype == jnp.float32 + if stochastic_rounding: + assert ScalingMode( + scaling_mode + ).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes" + # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64 + assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, ( + "sr_rng_state must be a uint32 array when stochastic_rounding is True but" + f" received {sr_rng_state_aval}" + ) + if is_outer: + assert ( + sr_rng_state_aval.shape[0] == num_of_devices() + and sr_rng_state_aval.shape[1] == 4 + ), ( + "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" + f" True and is_outer is True but received {sr_rng_state_aval.shape}" + ) + else: + assert sr_rng_state_aval.shape == (4,), ( + "Sharded sr_rng_state must be of shape (4,) per device when" + f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" + ) if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): rowwise_out_shape = out_shape @@ -98,14 +128,50 @@ def abstract( rowwise_out_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) + assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), ( + f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must" + f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}" + ) + updated_amax_aval = amax_aval + if use_rht: + assert ( + x_aval.dtype == jnp.bfloat16 + ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion." + + if flatten_axis < 0: + flatten_axis += len(x_aval.shape) + rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1) + cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1) + assert rows % 64 == 0 and cols % 128 == 0, ( + "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be" + f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape" + f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}." + ) + + assert ( + rht_matrix_aval is not None + and rht_matrix_aval.dtype == jnp.bfloat16 + and rht_matrix_aval.shape == (16, 16) + ), "rht_matrix must be of shape (16, 16) and dtype bfloat16" + assert ( + post_rht_amax_aval is not None + and post_rht_amax_aval.dtype == jnp.float32 + and post_rht_amax_aval.size == 1 + ), "post_rht_amax must be of dtype float32" + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) + ).get_scale_shape_2x( + x_aval.shape, + is_padded=not is_outer, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -126,6 +192,7 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(scale_dtype), scaling_mode, QuantizeLayout( q_layout @@ -172,6 +239,9 @@ def lowering( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, *, out_dtype, scaling_mode, @@ -180,12 +250,14 @@ def lowering( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p lowering rules """ del out_dtype, scale_dtype, is_outer - x_aval, scale_aval, amax_aval = ctx.avals_in + x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == amax_aval.dtype == jnp.float32 return ffi.ffi_lowering( @@ -196,10 +268,15 @@ def lowering( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, is_dbias=is_dbias, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) @staticmethod @@ -207,6 +284,9 @@ def impl( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype, scaling_mode, q_layout, @@ -214,6 +294,8 @@ def impl( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p implementation @@ -232,6 +314,9 @@ def impl( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -239,10 +324,14 @@ def impl( scale_dtype=scale_dtype, is_dbias=is_dbias, is_outer=False, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis) + ).get_scale_shape_2x( + x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True + ) scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) @@ -271,6 +360,8 @@ def batcher( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ to describe batch rules for vmap @@ -278,8 +369,8 @@ def batcher( del is_outer check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale, amax = batched_args - x_bdim, scale_bdim, amax_bdim = batch_dims + x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args + x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( @@ -287,12 +378,17 @@ def batcher( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, flatten_axis=flatten_axis, scale_dtype=scale_dtype, is_dbias=is_dbias, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ), out_bdims, ) @@ -306,11 +402,20 @@ def infer_sharding_from_operands( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, arg_infos, result_infos, ): - del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. + del ( + out_dtype, + result_infos, + scale_dtype, + is_outer, + stochastic_rounding, + use_rht, + ) # Unused. x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) @@ -320,7 +425,7 @@ def infer_sharding_from_operands( desc="BaseDBiasQuantizePrimitive.out_sharding", ) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -340,11 +445,19 @@ def infer_sharding_from_operands( ) scale_inv_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv_spec = scale_inv_spec + if ( + ScalingMode(scaling_mode).is_block_scaling + and ScalingMode(scaling_mode).is_colwise_transposed + ): + colwise_scale_inv_spec = multidim_transpose( + scale_inv_spec, transpose_axis=flatten_axis + ) + else: + colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" @@ -376,11 +489,13 @@ def partition( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, arg_infos, result_infos, ): - del result_infos, is_outer + del result_infos, is_outer # Unused. x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) @@ -389,8 +504,9 @@ def partition( PartitionSpec(*x_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) + if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -410,11 +526,19 @@ def partition( ) scale_inv_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv_spec = scale_inv_spec + if ( + ScalingMode(scaling_mode).is_block_scaling + and ScalingMode(scaling_mode).is_colwise_transposed + ): + colwise_scale_inv_spec = multidim_transpose( + scale_inv_spec, transpose_axis=flatten_axis + ) + else: + colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" @@ -428,6 +552,7 @@ def partition( desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) + # TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) out_shardings = ( out_sharding, @@ -438,7 +563,7 @@ def partition( dbias_sharding, ) - def sharded_impl(x, scale, amax): + def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): ( local_x, local_colwise_x, @@ -450,6 +575,9 @@ def sharded_impl(x, scale, amax): x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -457,6 +585,8 @@ def sharded_impl(x, scale, amax): scale_dtype=scale_dtype, is_dbias=is_dbias, is_outer=True, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: @@ -489,35 +619,54 @@ def shardy_sharding_rule( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, value_types, result_types, ): - del out_dtype, scale_dtype, is_outer, mesh, result_types + del ( + out_dtype, + scale_dtype, + is_outer, + stochastic_rounding, + use_rht, + mesh, + result_types, + ) prefix = "DBiasQuantize_" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[0].shape, unique_var=prefix + "x", flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, ) x_axes = scale_rules.input_spec - colwise_scale_inv = scale_rules.colwise_rule out = x_axes colwise_out = (prefix + "out_colwise",) + colwise_scale_inv = (prefix + "colwise_scale_inv",) if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + colwise_scale_inv = scale_rules.colwise_rule + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) + colwise_scale_inv = tuple( + multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis) + ) else: colwise_out = x_axes dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) amax = (prefix + "amax",) + sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis") + + post_rht_amax = (prefix + "post_rht_amax",) + rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2") return SdyShardingRule( - (x_axes, ("…1",), amax), + (x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), **scale_rules.factor_sizes, ) @@ -534,141 +683,6 @@ class QuantizePrimitive(BaseDBiasQuantizePrimitive): """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -class AmaxScope(Enum): - """ - Amax Scope Enum - """ - - LOCAL = 1 - TPSP = 2 - FSDP = 3 - - def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): - """Reduce the amax based on its scope""" - gmesh = global_mesh_resource() - sequence_dim = 0 if transpose_batch_sequence else 1 - # Run AR across TPSP only when tensor-sequence is detected in the input spec - if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: - return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) - # Run AR across FSDP - if self is AmaxScope.FSDP: - return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) - return amax - - -class AmaxCalculationPrimitive(BasePrimitive): - """ - Amax Calculation Primitive with custom_partitioning - """ - - name = "jax_local_amax" - multiple_results = False - impl_static_args = ( - 1, - 2, - ) # amax_scope, transpose_batch_sequence - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract( - x_aval, - *, - amax_scope, - transpose_batch_sequence, - ): - """ - amax calcuation abstract - """ - del amax_scope, transpose_batch_sequence - - dtype = dtypes.canonicalize_dtype(x_aval.dtype) - assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - - out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - return out_aval - - @staticmethod - def impl( - x, - amax_scope, - transpose_batch_sequence, - ): - """ - amax calcuation implementation - """ - del amax_scope, transpose_batch_sequence - amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,)) - return amax - - @staticmethod - def infer_sharding_from_operands( - amax_scope, - transpose_batch_sequence, - mesh, - arg_infos, - result_infos, - ): - """ - amax calcuation infer_sharding_from_operands - """ - del (amax_scope, transpose_batch_sequence, arg_infos, result_infos) # Unused. - amax_sharding = NamedSharding( - mesh, - PartitionSpec(None), - desc="AmaxCalculationPrimitive.out_sharding", - ) - return amax_sharding - - @staticmethod - def partition( - amax_scope, - transpose_batch_sequence, - mesh, - arg_infos, - result_infos, - ): - """ - amax calcuation partition - """ - del result_infos - x_spec = get_padded_spec(arg_infos[0]) - amax_sharding = NamedSharding( - mesh, - PartitionSpec(None), - desc="AmaxCalculation.amax_sharding", - ) - - def sharded_impl(x): - amax = AmaxCalculationPrimitive.impl( - x, - amax_scope=amax_scope, - transpose_batch_sequence=transpose_batch_sequence, - ) - amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( - amax, x_spec, transpose_batch_sequence, mesh - ) - - return amax - - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - return mesh, sharded_impl, amax_sharding, arg_shardings - - @staticmethod - def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types): - """ - amax calcuation shardy_sharding_rule - """ - del amax_scope, transpose_batch_sequence, mesh, result_types - prefix = "AmaxCal" - input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape))) - output_spec = (f"{prefix}_amax",) - return SdyShardingRule((input_spec,), (output_spec,)) - - -register_primitive(AmaxCalculationPrimitive, outer_only=True) - - def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): @@ -740,7 +754,11 @@ def _quantize_dbias_impl( # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): + is_unsupported = ( + quantizer.q_layout == QuantizeLayout.COLWISE + and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING + ) + if is_unsupported or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -767,15 +785,32 @@ def _quantize_dbias_impl( dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias + use_rht = False + scale = jnp.empty((1,), jnp.float32) - amax = None + post_rht_amax = None + rht_matrix = jnp.empty((1, 1), jnp.bfloat16) + amax = x.amax + + if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout): + use_rht = True + rht_matrix = get_rht_matrix() + + new_amax, post_rht_amax = calculate_post_rht_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + produce_regular_amax=amax is None, + flatten_axis=flatten_axis, + ) + if amax is None: + # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel + # So here we only calculate and update amax when it is not provided from a previous layer (amax is None) + amax = new_amax + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - # Globally reduce amax across all devices for current scaling so we have a single global scale. - # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this - # until the tensor is dequantized (e.g. in the GEMM). - amax = x.amax if amax is None: - amax = AmaxCalculationPrimitive.outer_primitive.bind( + amax = calculate_amax( x.data, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -783,8 +818,17 @@ def _quantize_dbias_impl( scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # Make sure to reset amax to zeros for DelayedScaling + amax = jnp.zeros((1,), jnp.float32) + elif quantizer.scaling_mode.is_nvfp4_scaling: + if amax is None: + amax = calculate_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) - # Make sure amax is init with zero + # Make sure amax is not None if amax is None: amax = jnp.zeros((1,), jnp.float32) @@ -796,9 +840,16 @@ def _quantize_dbias_impl( and is_1x_kernel_supported ) q_layout = quantizer.q_layout + if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE + sr_rng_state = None + if quantizer.scaling_mode.is_nvfp4_scaling: + # Only NVFP4 scaling modes support stochastic rounding + if quantizer.stochastic_rounding_rng_state is not None: + sr_rng_state = quantizer.stochastic_rounding_rng_state + ( rowwise_casted_output, colwise_casted_output, @@ -810,13 +861,18 @@ def _quantize_dbias_impl( x.data, scale, amax, + sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32), + post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), + rht_matrix, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout.value, flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), - is_dbias=is_dbias, + is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_outer=True, + stochastic_rounding=sr_rng_state is not None, + use_rht=use_rht, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): @@ -830,14 +886,17 @@ def _quantize_dbias_impl( colwise_casted_output = jnp.transpose( rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) ) - quantizer.update(updated_amax) + if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias: + dbias = _jax_dbias(x, flatten_axis=flatten_axis) out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, colwise_data=colwise_casted_output, colwise_scale_inv=colwise_scale_inv, + amax=updated_amax, + colwise_amax=post_rht_amax, scaling_mode=quantizer.scaling_mode, dq_dtype=dq_dtype, q_layout=quantizer.q_layout, @@ -955,6 +1014,11 @@ def abstract( # TODO(Phuong): can scale_aval be None? assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), ( + f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must" + f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}" + ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_grouped_scale_shape_2x( diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 3ce6dee73..87c6fa91c 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -85,7 +85,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, + DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, QuantizeLayout q_layout); @@ -138,6 +138,10 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); +// Amax +XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); + // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp new file mode 100644 index 000000000..46f167fca --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -0,0 +1,100 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ +#include + +#include + +#include "../extensions.h" +#include "transformer_engine/cast.h" +#include "transformer_engine/hadamard_transform.h" +#include "transformer_engine/recipe.h" +#include "transformer_engine/transformer_engine.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +Error_Type RHTAmaxCalculationFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf, + Result_Type post_rht_amax_buf, + int64_t rht_matrix_random_sign_mask_t, bool produce_regular_amax, + int64_t flatten_axis) { + NVTE_CHECK(input_buf.untyped_data() != nullptr, + "Input must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(input_buf.element_type()) == DType::kBFloat16, + "Input must be of type bfloat16 for RHT Amax calculation"); + + NVTE_CHECK(flatten_axis > 0 && flatten_axis < static_cast(input_buf.dimensions().size()), + "Flatten axis is out of bounds"); + TensorWrapper input_tensor(input_buf.untyped_data(), + std::vector{product(input_buf.dimensions(), 0, flatten_axis), + product(input_buf.dimensions(), flatten_axis, + input_buf.dimensions().size())}, + convert_ffi_datatype_to_te_dtype(input_buf.element_type())); + + float *amax_out = nullptr; + if (produce_regular_amax) { + amax_out = reinterpret_cast(amax_buf->untyped_data()); + NVTE_CHECK(amax_out != nullptr, "Amax output must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(amax_buf->element_type()) == DType::kFloat32, + "Amax output must be of type float32 for RHT Amax calculation"); + NVTE_CHECK(amax_buf->dimensions().size() == 1 && amax_buf->dimensions()[0] == 1, + "Amax output must be a single float for RHT Amax calculation"); + } + + float *post_rht_amax_out = reinterpret_cast(post_rht_amax_buf->untyped_data()); + NVTE_CHECK(post_rht_amax_out != nullptr, + "Post-RHT Amax output must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(post_rht_amax_buf->element_type()) == DType::kFloat32, + "Post-RHT Amax output must be of type float32 for RHT Amax calculation"); + NVTE_CHECK(post_rht_amax_buf->dimensions().size() == 1 && post_rht_amax_buf->dimensions()[0] == 1, + "Post-RHT Amax output must be a single float for RHT Amax calculation"); + + TensorWrapper out_tensor{}; + out_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + out_tensor.set_columnwise_amax(post_rht_amax_out, DType::kFloat32, std::vector{1}); + + // Zero'ing of amaxes is handled by TE common inside nvte_hadamard_transform_amax + nvte_hadamard_transform_amax(input_tensor.data(), out_tensor.data(), + 0, // Regular amax for rowwise does not apply RHT so mask is 0 + rht_matrix_random_sign_mask_t, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RHTAmaxCalculationHandler, RHTAmaxCalculationFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // amax + .Ret() // post_rht_amax + .Attr("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t + .Attr("produce_regular_amax") // produce_regular_amax + .Attr("flatten_axis"), // flatten_axis + FFI_CudaGraph_Traits); + +Error_Type RHTAmaxCalculationInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type amax_buf, Result_Type post_rht_amax_buf, + int64_t rht_matrix_random_sign_mask_t, + bool produce_regular_amax, int64_t flatten_axis) { + return wrapInStreamCapture(std::function(RHTAmaxCalculationFFI), stream, input_buf, amax_buf, + post_rht_amax_buf, rht_matrix_random_sign_mask_t, produce_regular_amax, + flatten_axis); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RHTAmaxCalculationInitializeHandler, RHTAmaxCalculationInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // amax + .Ret() // post_rht_amax + .Attr("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t + .Attr("produce_regular_amax") // produce_regular_amax + .Attr("flatten_axis")); // flatten_axis + +} // namespace jax +} // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index e77c38e99..a0425efda 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -41,6 +41,9 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E8M0FNU: return DType::kFloat8E8M0; break; + case xla::ffi::DataType::F4E2M1FN: + return DType::kFloat4E2M1; + break; default: auto type_num = static_cast(type); NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 82f062a15..0fc2e8389 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -102,6 +102,8 @@ inline static size_t te_dtype_bytes(const DType& type) { return 1; case DType::kFloat8E8M0: return 1; + case DType::kFloat4E2M1: + return 1; default: NVTE_ERROR("Unsupported DType: ", static_cast(type)); } diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 993ec1377..8a3658a0b 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -51,7 +51,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set scaling factor for quantized tensors if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { - NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1, + "Quantized GEMM requires 4-bit or 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); std::vector scale_shape = {1}; @@ -74,7 +75,8 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, - Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, + Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, @@ -119,6 +121,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out @@ -136,11 +140,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, - JAXX_Collective_Op collective_op) { + Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || @@ -192,10 +196,31 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); - - // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); + float one = 1.; + float zero = 0.; + // alpha, beta + float *alpha_ptr = &one, *beta_ptr = &zero; + if (is_nvfp4_scaling(scaling_mode)) { + NVTE_CHECK(alpha.element_count() == 1 && + convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32); + alpha_ptr = reinterpret_cast(alpha.untyped_data()); + NVTE_CHECK(beta.element_count() == 1 && + convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32); + beta_ptr = reinterpret_cast(beta.untyped_data()); + } + + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sm); + if (fuse_bias) config.set_bias_tensor(bias_.data()); + if (fuse_gelu) { + config.set_with_gelu_epilogue(true); + config.set_epilogue_aux_tensor(pre_gelu_.data()); + } + if (collective_op == JAXX_Collective_Op::NONE) { auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); NVTE_CHECK(out_.numel() == output->element_count(), @@ -205,9 +230,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, ", out_shape[1]=", out_shape[1]); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr, + rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/, + out_.data() /*D*/, workspace_.data(), config, stream); } else { std::vector buffer_shape{0, 0}; DType buffer_dtype = out_dtype; @@ -268,6 +294,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out @@ -599,9 +627,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // point to swizzled scale_inv data (store on workspace, only used for GEMM). // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers auto lhs_sinv_shape_i = - get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); + get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); auto rhs_sinv_shape_i = - get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); + get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; if (lhs_use_colwise) { diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index ee81b5ad7..176115ade 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -26,11 +26,21 @@ std::vector Shape::to_vector() const { return shape; } -std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) { - auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x; - auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y; - auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x; - auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y; +std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, + bool is_colwise) { + auto block_size = BLOCK_SIZE(1, 1); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + block_size = MXFP8_BLOCK_SIZE; + } else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) { + block_size = NVFP4_BLOCK_SIZE; + } else { + NVTE_ERROR("Unsupported scaling_mode = ", static_cast(scaling_mode)); + } + auto block_x = is_colwise ? block_size.y : block_size.x; + auto block_y = is_colwise ? block_size.x : block_size.y; + auto alignment_x = is_colwise ? BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x; + auto alignment_y = is_colwise ? BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y; NVTE_CHECK(M % block_x == 0, "M must be divisble by %zu (got %zu)", block_x, M); NVTE_CHECK(N % block_y == 0, "N must be divisble by %zu (got %zu)", block_y, N); diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index c8fb713d7..07e9aec7e 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -45,6 +45,8 @@ enum class JAXX_Scaling_Mode : int64_t { DELAYED_TENSOR_SCALING = 1, MXFP8_1D_SCALING = 2, CURRENT_TENSOR_SCALING = 3, + NVFP4_1D_SCALING = 4, + NVFP4_2D_SCALING = 5, }; inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { @@ -56,6 +58,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); } +inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: @@ -70,22 +77,32 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING: return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; break; + case JAXX_Scaling_Mode::NVFP4_1D_SCALING: + return NVTEScalingMode::NVTE_NVFP4_1D_SCALING; + break; + case JAXX_Scaling_Mode::NVFP4_2D_SCALING: + // TE common uses the same enum value for 1D and 2D fp4 scaling and instead differentiates them via quant_config.nvfp4_2d_quantization + return NVTEScalingMode::NVTE_NVFP4_1D_SCALING; + break; default: NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); break; } } -constexpr struct BlockSize { +struct BLOCK_SIZE { size_t x; size_t y; -} MXFP8_BLOCK_SIZE{1, 32}; -constexpr struct Alignment { - size_t x; - size_t y; -} MXFP8_ALIGNMENT{128, 4}; + constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {} +}; + +constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32}; +constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16}; + +constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4}; -std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, + bool is_colwise); template void hash_combine(int64_t &seed, const T &v, Rest... rest) { diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index f6b1acd43..d740df0e2 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -76,6 +76,11 @@ pybind11::dict Registrations() { pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + // Amax + dict["te_rht_amax_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); + return dict; } @@ -106,7 +111,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("kFloat16", DType::kFloat16) .value("kBFloat16", DType::kBFloat16) .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", DType::kFloat8E5M2); + .value("kFloat8E5M2", DType::kFloat8E5M2) + .value("kFloat8E8M0", DType::kFloat8E8M0) + .value("kFloat4E2M1", DType::kFloat4E2M1); pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) @@ -165,6 +172,8 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) + .value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING) + .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .export_values(); pybind11::enum_(m, "QuantizeLayout", diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 05260741b..a45a69882 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -5,8 +5,11 @@ ************************************************************************/ #include +#include + #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" #include "xla/ffi/api/c_api.h" @@ -15,7 +18,7 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, + DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, QuantizeLayout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; @@ -30,16 +33,22 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ // this function. We pass a dummy pointer as a workaround. int temp = 0; + bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto scale_shape = std::vector{1}; // Only the pointers will be checked for scale_inv, thus the shapes do not matter if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); - if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, - std::vector{1}); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + if (is_nvfp4) + scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false); + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), scale_dtype, + scale_shape); } } @@ -49,13 +58,16 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter - if (is_fp8_dtype(out_dtype)) { - output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, - std::vector{1}); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + if (is_nvfp4) + scale_shape = + get_block_scale_shape(scaling_mode, hidden_size, batch_size, false); //Transpose + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), scale_dtype, + scale_shape); } } - if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -72,17 +84,20 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Buffer_Type amax_buf, Result_Type output_buf, - Result_Type output_trans_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, - bool is_dbias, int64_t flatten_axis) { + Buffer_Type amax_buf, Buffer_Type sr_rng_state, + Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis, + bool stochastic_rounding, bool use_rht) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); - NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization."); + NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype), + "Output datatype must be FP8 or FP4 for quantization."); auto *input = input_buf.untyped_data(); @@ -112,41 +127,106 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; + + NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); + NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); if (quantize_layout == QuantizeLayout::ROWWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); - if (is_fp8_dtype(out_dtype)) { - if (is_tensor_scaling) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf.untyped_data()); - float *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax == updated_amax && amax != nullptr, - "amax must be provided for delayed tensor scaling"); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{1}); - } else { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), - product(scale_inv_buf->dimensions(), flatten_axis, - scale_inv_buf->dimensions().size())}); - } + if (is_tensor_scaling) { + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); } } + if (is_nvfp4) { + float *amax = reinterpret_cast(amax_buf.untyped_data()); + NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4"); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } + + QuantizationConfigWrapper quant_config{}; + if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) { + quant_config.set_nvfp4_2d_quantization(true); + } + + // Stochastic rounding + quant_config.set_stochastic_rounding(stochastic_rounding); + TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector{2}, + DType::kInt64); + if (stochastic_rounding) { + NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t), + "rng_state must be of type int64[2]"); + NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR"); + quant_config.set_rng_state(sr_rng_state_tensor.data()); + } + if (quantize_layout == QuantizeLayout::COLWISE || quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) - ? output_trans_shape - : output_shape; + if (is_nvfp4 && use_rht) { + if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + // Do regular rowwise quantization without RHT + nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); + } + + TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode)); + + // nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be populated in the rowwise buffers on TensorWrapper + out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape); + auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis; + out_transpose.set_rowwise_scale_inv( + colwise_scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), + std::vector{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis), + product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis, + colwise_scale_inv_buf->dimensions().size())}); + + float *post_rht_amax = reinterpret_cast(post_rht_amax_buf.untyped_data()); + NVTE_CHECK(post_rht_amax != nullptr, "Post-RHT colwise amax must be provided for NVFP4"); + out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector{1}); + + bool const eligible_for_rht_cast_fusion = + input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0; + NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met"); + + NVTE_CHECK( + convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16, + "RHT matrix must be bf16"); + NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 && + rht_matrix_buf.dimensions()[1] == 16, + "RHT matrix must be 16x16"); + TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector{16, 16}, + DType::kBFloat16); + + nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(), + rht_matrix_tensor.data(), quant_config, + stream); + + return ffi_with_cuda_error_check(); + } + + bool const is_colwise_transposed = + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; + auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling auto &tmp_buf = is_tensor_scaling ? scale_inv_buf : colwise_scale_inv_buf; @@ -156,26 +236,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); } else { + auto colwise_flatten_axis = flatten_axis; + if (is_colwise_transposed) { + // convert flatten_axis from N layout to T layout + colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis; + } output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{ - product(tmp_buf->dimensions(), 0, flatten_axis), - product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + product(tmp_buf->dimensions(), 0, colwise_flatten_axis), + product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())}); } } - if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { - output_tensor.set_amax(nullptr, DType::kFloat32, std::vector{1}); - } - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); if (is_dbias) { + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING, + "DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot " + "take quant_config as input."); nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), workspace_tensor.data(), stream); } else { - nvte_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); } return ffi_with_cuda_error_check(); } @@ -186,6 +270,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Arg() // input .Arg() // scale .Arg() // amax + .Arg() // sr_rng_state + .Arg() // colwise amax + .Arg() // rht matrix .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -196,7 +283,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Attr("scaling_mode") .Attr("q_layout") .Attr("is_dbias") - .Attr("flatten_axis"), + .Attr("flatten_axis") + .Attr("stochastic_rounding") + .Attr("use_rht"), FFI_CudaGraph_Traits); Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -346,7 +435,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty sinv_size = 1; } else { const bool is_colwise = false; - auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise); out_i.set_rowwise_scale_inv(static_cast(sinv_ptr), sinv_dtype, sinv_shape_i); sinv_size = product(sinv_shape_i); } @@ -365,7 +454,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty colwise_sinv_size = 1; } else { const bool is_colwise = true; - auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise); out_i.set_columnwise_scale_inv(static_cast(colwise_sinv_ptr), sinv_dtype, sinv_shape_i); colwise_sinv_size = product(sinv_shape_i); diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 28525a22a..44c73a5b1 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex -from .cpp_extensions.quantization import AmaxScope +from .cpp_extensions.amax import AmaxScope from .quantize import ( ScaledTensorFactory, ScalingMode, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 76865f7c1..c54ecb236 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -15,7 +15,6 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from transformer_engine.common import recipe from ..dense import dense @@ -35,10 +34,9 @@ from ..quantize import ( QuantizerFactory, get_quantize_config, - QuantizeMeta, QuantizeMetaSet, - ScalingMode, TensorSource, + get_quantize_config_with_recipe, ) PRNGKey = Any @@ -353,40 +351,32 @@ def generate_quantizer_set( Generate a set of FP8 meta for a GEMM. """ - def generate_quantize_meta(quantizer_name: str): - collection_name = ( - variable_collection - if variable_collection is not None - else get_quantize_config().COLLECTION_NAME - ) - scale = self.variable( - collection_name, - f"{quantizer_name}{postfix}_scale", - jnp.ones, - (1,), - jnp.float32, - ).value - amax_history = self.variable( - collection_name, - f"{quantizer_name}{postfix}_amax_history", - jnp.zeros, - (get_quantize_config().AMAX_HISTORY_LEN,), - jnp.float32, - ).value - return QuantizeMeta(scale=scale, amax_history=amax_history) - - if get_quantize_config().get_scaling_mode( - TensorSource.X - ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling): - x_meta = generate_quantize_meta("x") - kernel_meta = generate_quantize_meta("kernel") - grad_meta = generate_quantize_meta("grad") - quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) - kwargs = {"quantize_meta_set": quantize_meta_set} + collection_name = ( + variable_collection + if variable_collection is not None + else get_quantize_config().COLLECTION_NAME + ) + + if fp8_recipe is None: + quantize_config = get_quantize_config() else: - kwargs = {} + quantize_config = get_quantize_config_with_recipe(fp8_recipe) - quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs) + x_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.X, "x" + ) + kernel_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.KERNEL, "kernel" + ) + grad_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.DGRAD, "grad" + ) + + quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) + + quantizer_set = QuantizerFactory.create_set( + fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set + ) return quantizer_set diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 136f43df4..705c74232 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -16,7 +16,7 @@ import jax.numpy as jnp from . import cpp_extensions as tex -from .cpp_extensions.quantization import AmaxScope +from .cpp_extensions.amax import AmaxScope from .quantize import ( QuantizerSet, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index c43430cf3..100848fdd 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,7 +21,7 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex -from .cpp_extensions.quantization import AmaxScope +from .cpp_extensions.amax import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index 11f692917..9616965c7 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -14,5 +14,6 @@ from .dequantizer import * from .scaling_modes import * from .metadata import * +from .hadamard import * from .helper import * from .device_utils import * diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 9d46c3c30..b4da6f3be 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -15,6 +15,8 @@ import jax.numpy as jnp from .scaling_modes import ScalingMode +from .hadamard import apply_rht, should_use_rht + __all__ = ["ScalingModeToDequantizerMap"] @@ -119,7 +121,7 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte 0 < flatten_axis < len(data_shape) ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" scale_shape = scaling_mode.get_scale_shape( - data_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + data_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) data = data.reshape( @@ -161,10 +163,99 @@ def dequantize(scaled_tensor): ) +class NVFP4Dequantizer(Dequantizer): + """NVFP4 Dequantizer Class. + + This class provides static methods for dequantizing tensors that have been + quantized using NVFP4 scaling modes. + """ + + @staticmethod + def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis): + """Dequantize a tensor using block scaling. + + Args: + data: The quantized tensor data + scale_inv: The inverse scaling factors + amax: The maximum absolute value of the tensor + dq_dtype: The data type for dequantized values + scaling_mode: The scaling mode used for quantization + is_colwise: Whether the scaling is column-wise + flatten_axis: The axis along which the tensor could be flattened to 2D + + Returns: + The dequantized tensor + """ + + DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32) + tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) + + data = data.astype(jnp.float32) + scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv + data_layout = "T" if is_colwise else "N" + + data_shape = data.shape + flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + scale_shape = scaling_mode.get_scale_shape( + data_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + # expect the flatten_axis wrt the N layout + flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + + data = data.reshape( + *data_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *data_shape[flatten_axis:-1], + scale_shape[-1], + int(data_shape[-1] / scale_shape[-1]), + ) + + scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1)) + out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape) + + # Apply inverse of RHT if needed + use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise) + if use_rht: + out = apply_rht(out, inverse=True) + + return out + + @staticmethod + def dequantize(scaled_tensor): + """Dequantize a tensor using block scaling. + + Args: + scaled_tensor: The quantized tensor to dequantize + + Returns: + The dequantized tensor + """ + return NVFP4Dequantizer._dequantize_func( + scaled_tensor.data, + scaled_tensor.scale_inv, + scaled_tensor.amax, + scaled_tensor.dq_dtype, + scaled_tensor.scaling_mode, + scaled_tensor.is_colwise, + scaled_tensor.flatten_axis, + ) + + ScalingModeToDequantizerMap = { ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer, + ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer, + ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer, ScalingMode.NO_SCALING: NoopDequantizer, } @@ -210,13 +301,13 @@ def _grouped_dequantize(grouped_scaled_tensor): ) padded_scale_shape_i = scaling_mode.get_scale_shape( data_shape_i, - grouped_scaled_tensor.is_colwise, + is_colwise=grouped_scaled_tensor.is_colwise, is_padded=True, flatten_axis=flatten_axis, ) unpadded_scale_shape_i = scaling_mode.get_scale_shape( data_shape_i, - grouped_scaled_tensor.is_colwise, + is_colwise=grouped_scaled_tensor.is_colwise, is_padded=False, flatten_axis=flatten_axis, ) diff --git a/transformer_engine/jax/quantize/hadamard.py b/transformer_engine/jax/quantize/hadamard.py new file mode 100644 index 000000000..c0b74ef75 --- /dev/null +++ b/transformer_engine/jax/quantize/hadamard.py @@ -0,0 +1,72 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Randomized Hadamard Transform (RHT) utilities for JAX.""" +import jax.numpy as jnp + +from .scaling_modes import ScalingMode + + +def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool: + """Determine if RHT (Randomized Hadamard Transform) should be used. + + Args: + scaling_mode: The scaling mode of the tensor. + is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided. + q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided. + + Returns: + bool: True if RHT should be used, False otherwise. + """ + # Delayed import to avoid circular dependencies + from .quantizer import QuantizeLayout + + assert (is_colwise is None) != ( + q_layout is None + ), "Exactly one of is_colwise or q_layout must be provided." + + if q_layout is not None: + is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE} + + return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise + + +def get_wgrad_sign_vector() -> list[int]: + """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization.""" + return [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1] + + +def get_sign_from_vector(vector: list[int]) -> int: + """Convert a sign vector to a bitmask integer.""" + mask = 0 + for i, v in enumerate(vector): + mask |= (v == -1) << i + return mask + + +def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray: + """Apply the Randomized Hadamard Transform (RHT) to the input tensor.""" + h = get_rht_matrix() + block_size = 16 + if inverse: + h = jnp.linalg.inv(h.astype(jnp.float32)).astype(jnp.bfloat16) + # TODO(jberchtold): These reshapes will break partitioning, fixme + return (x.reshape(-1, block_size) @ h).reshape(x.shape) + + +def get_rht_matrix() -> jnp.ndarray: + """Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization. + + Returns: + A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask. + """ + import scipy + + block_size = 16 + h = jnp.array(scipy.linalg.hadamard(block_size)) + + # Apply the random sign mask + s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32) + h = jnp.diag(s) @ h + + return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 67f0a68c6..70611cbea 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -11,9 +11,12 @@ from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Dict, Union, Sequence, Type -from functools import reduce +from typing import Optional, Tuple, Dict, Union, Sequence, Type, List +from functools import reduce, lru_cache import operator +from importlib.metadata import version as get_pkg_version +import warnings +from packaging.version import Version as PkgVersion import jax import jax.numpy as jnp @@ -21,18 +24,27 @@ from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version from transformer_engine.common import recipe -from transformer_engine.jax.sharding import global_shard_guard, MeshResource - +from transformer_engine.jax.sharding import ( + global_shard_guard, + MeshResource, + num_of_devices, + get_all_mesh_axes, + with_sharding_constraint, +) + +from .metadata import QuantizeMeta from .scaling_modes import ScalingMode -from .. import cpp_extensions as tex from .device_utils import get_device_compute_capability __all__ = [ "get_quantize_config", + "get_quantize_config_with_recipe", "fp8_autocast", "is_fp8_available", + "is_scaling_mode_supported", + "get_supported_scaling_modes", + "get_supported_quantization_recipes", "update_collections", - "get_delayed_scaling", "apply_padding_to_scale_inv", "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", @@ -41,11 +53,23 @@ _is_fp8_available = None _reason_for_no_fp8 = "" +_is_scaling_mode_supported = None +_reason_for_no_scaling_mode = "" Collection = Union[Dict, FrozenDict] NVTE_FP8_COLLECTION_NAME = "fp8_metas" +@lru_cache(maxsize=None) +def _jax_version_meet_requirement(version: str): + """ + Helper function checking if required JAX version is available + """ + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -55,8 +79,6 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ - if gpu_arch >= 90: # hopper and above - return True, "" if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if get_cublasLt_version() < 120103: @@ -75,20 +97,31 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ - if gpu_arch >= 100: # blackwell and above - return True, "" if gpu_arch < 99: # pre-blackwell return False, "Device compute capability 9.9 or higher required for MXFP8 execution." if get_cublasLt_version() < 120800: return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution." - if get_cuda_version() < 12010: + if get_cuda_version() < 12080: return False, "Cuda version 12.8 or higher required for MXFP8 execution." - if not tex.jax_version_meet_requirement("0.5.3"): + if not _jax_version_meet_requirement("0.5.3"): return False, "Jax version 0.5.3 or higher required for MXFP8 execution." return True, "" -def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: +def _check_fp4_support(gpu_arch) -> Tuple[bool, str]: + """Check if FP4 is supported for the given GPU architecture.""" + if gpu_arch < 100: # pre-blackwell + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." + if get_cublasLt_version() < 120800: + return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution." + if get_cuda_version() < 12080: + return False, "Cuda version 12.8 or higher required for NVFP4 execution." + if not _jax_version_meet_requirement("0.5.3"): + return False, "Jax version 0.5.3 or higher required for NVFP4 execution." + return True, "" + + +def _check_scaling_support(scaling_mode: ScalingMode, gpu_id: int) -> Tuple[bool, str]: """Check if FP8 is supported for the given scaling mode and GPU. Args: @@ -101,9 +134,35 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: gpu_arch = get_device_compute_capability(gpu_id) if scaling_mode.is_tensor_scaling(): return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + if scaling_mode.is_mxfp8_scaling: return _check_block_scaling_fp8_support(gpu_arch) - return (False, "Unsupported scaling_mode!") + if scaling_mode.is_nvfp4_scaling: + return _check_fp4_support(gpu_arch) + return (True, "") # NO_SCALING is always supported + + +def is_scaling_mode_supported( + scaling_mode=ScalingMode.NO_SCALING, + gpu_id=None, +) -> Tuple[bool, str]: + """Check if the given scaling mode is available for the given GPU.""" + if gpu_id is not None: + return _check_scaling_support(scaling_mode, gpu_id) + + global _is_scaling_mode_supported, _reason_for_no_scaling_mode + if _is_scaling_mode_supported is None: + _is_scaling_mode_supported = {} + _reason_for_no_scaling_mode = {} + if scaling_mode not in _is_scaling_mode_supported: + _is_scaling_mode_supported[scaling_mode] = True + _reason_for_no_scaling_mode[scaling_mode] = "" + for local_gpu_id in range(len(jax.local_devices())): + ret, msg = _check_scaling_support(scaling_mode, local_gpu_id) + if ret is False: + _is_scaling_mode_supported[scaling_mode] = ret + _reason_for_no_scaling_mode[scaling_mode] = msg + return ret, msg + return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode] def is_fp8_available( @@ -119,26 +178,36 @@ def is_fp8_available( Returns: A tuple of (bool, str) indicating availability and any error message """ - if gpu_id is not None: - return _check_fp8_support(scaling_mode, gpu_id) - - global _is_fp8_available, _reason_for_no_fp8 - if _is_fp8_available is None: - _is_fp8_available = {} - _reason_for_no_fp8 = {} - - if scaling_mode not in _is_fp8_available: - _is_fp8_available[scaling_mode] = True - _reason_for_no_fp8[scaling_mode] = "" - # JAX doesn't provide the local GPU id. - for local_gpu_id in range(len(jax.local_devices())): - ret, msg = _check_fp8_support(scaling_mode, local_gpu_id) - if ret is False: - _is_fp8_available[scaling_mode] = ret - _reason_for_no_fp8[scaling_mode] = msg - return ret, msg - - return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode] + warnings.warn( + "is_fp8_available is deprecated. Use is_scaling_mode_supported instead.", DeprecationWarning + ) + return is_scaling_mode_supported(scaling_mode=scaling_mode, gpu_id=gpu_id) + + +# TODO(Phuong): make the infrastruture to support NO_SCALING +def get_supported_scaling_modes() -> List[ScalingMode]: + """Get all supported quantization scaling modes.""" + return [ + scaling_mode + for scaling_mode in ScalingMode + if is_scaling_mode_supported(scaling_mode=scaling_mode)[0] + and scaling_mode != ScalingMode.NO_SCALING + ] + + +def get_supported_quantization_recipes() -> List[recipe.Recipe]: + """Get all supported quantization recipes.""" + # We don't support all the recipes TE/Common supports yet + # return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()] + all_recipes = [ + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), + recipe.MXFP8BlockScaling(), + recipe.NVFP4BlockScaling(), + ] + return [ + recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0] + ] def _format2dtypes(format_: recipe.Format): @@ -156,6 +225,8 @@ def _format2dtypes(format_: recipe.Format): return jnp.float8_e5m2, jnp.float8_e5m2 if format_ == recipe.Format.HYBRID: return jnp.float8_e4m3fn, jnp.float8_e5m2 + if format_ == recipe.Format.E2M1: + return jnp.float4_e2m1fn, jnp.float4_e2m1fn return jnp.bfloat16, jnp.bfloat16 @@ -193,7 +264,6 @@ class BaseQuantizeConfig(ABC): INITIALIZED: Whether the config has been initialized MARGIN: Margin value for quantization COLLECTION_NAME: Name of the collection for quantization metadata - FP8_FORMAT: FP8 format to use FWD_DTYPE: Forward pass data type BWD_DTYPE: Backward pass data type FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass @@ -207,28 +277,26 @@ class BaseQuantizeConfig(ABC): INITIALIZED = False MARGIN: float = 0.0 COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME - FP8_FORMAT: recipe.Format = recipe.Format.HYBRID - FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] - BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] + FWD_DTYPE: DType = None + BWD_DTYPE: DType = None FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False # DelayedScaling + # TODO(Phuong): move these two into DelayedScalingQuantizeConfig AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: - """Initialize the quantization configuration. + """Initialize the quantization configuration from a given recipe. Args: fp8_recipe: The FP8 recipe to use for initialization """ self.INITIALIZED = True - self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 - self.FP8_FORMAT = fp8_recipe.fp8_format - self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT) + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp8_format) def is_fp8_enabled(self) -> bool: """Check if FP8 quantization is enabled. @@ -249,6 +317,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: The scaling mode for the specified usage type. """ + @abstractmethod + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + def is_supported(self) -> tuple[bool, str]: """Check if this QuantizeConfig class is supported on the available devices. @@ -261,7 +350,7 @@ def is_supported(self) -> tuple[bool, str]: kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: - is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) + is_supported, reason = is_scaling_mode_supported(scaling_mode=scaling_mode) if not is_supported: return is_supported, reason return True, None @@ -281,6 +370,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.NO_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + class DelayedScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for delayed scaling FP8 recipe. @@ -299,6 +409,7 @@ def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: AssertionError: If recipe parameters are not supported """ super().initialize_from_recipe(fp8_recipe) + self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 assert fp8_recipe.amax_compute_algo in [ "max", @@ -323,6 +434,41 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.DELAYED_TENSOR_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + scale = module.variable( + collection_name, + f"{quantizer_name}{postfix}_scale", + jnp.ones, + (1,), + jnp.float32, + ).value + amax_history = module.variable( + collection_name, + f"{quantizer_name}{postfix}_amax_history", + jnp.zeros, + (self.AMAX_HISTORY_LEN,), + jnp.float32, + ).value + return QuantizeMeta(scale=scale, amax_history=amax_history) + class CurrentScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for current scaling FP8 recipe. @@ -344,6 +490,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.CURRENT_TENSOR_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + class BlockScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for block scaling FP8 recipe. @@ -365,6 +532,91 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.MXFP8_1D_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + + +class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): + """Configuration class for NVFP4 scaling recipe. + + This class provides specific initialization and finalization for NVFP4 scaling quantization mode. + """ + + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize block scaling FP8 configuration. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + """ + self.INITIALIZED = True + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format) + self.AMAX_HISTORY_LEN = 0 + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + if tensor_source == TensorSource.KERNEL: + return ScalingMode.NVFP4_2D_SCALING + # for x and grad + return ScalingMode.NVFP4_1D_SCALING + + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + if tensor_source != TensorSource.DGRAD: + # Only DGRAD uses stochastic rounding + return QuantizeMeta() + + # TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it. + sr_jax_rng = module.make_rng("sr_rng") + # Get a unique key for this quantizer + sr_jax_rng = jax.jit(jax.random.fold_in)( + sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max + ) + + # Generate 4 random uint32 values from the JAX PRNG key + sr_jax_rng_state = jax.random.randint( + sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 + ).view(jnp.uint32) + sr_jax_rng_state = with_sharding_constraint( + sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None) + ) + + return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state) + _QUANTIZE_CONFIG = NoOpQuantizeConfig() @@ -377,7 +629,7 @@ def get_quantize_config(): def get_quantize_config_class( fp8_recipe: recipe.Recipe, ) -> Type[BaseQuantizeConfig]: - """Get the quantization configuration based on the FP8 recipe. + """Get the quantization configuration class based on the FP8 recipe. Args: fp8_recipe: The FP8 recipe to use for initialization @@ -390,9 +642,18 @@ def get_quantize_config_class( return BlockScalingQuantizeConfig if isinstance(fp8_recipe, recipe.Float8CurrentScaling): return CurrentScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.NVFP4BlockScaling): + return NVFP4ScalingQuantizeConfig raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") +def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe): + """Get the quantization configuration object based on the FP8 recipe.""" + config = get_quantize_config_class(fp8_recipe)() + config.initialize_from_recipe(fp8_recipe) + return config + + @contextmanager def fp8_autocast( enabled: bool = False, @@ -457,31 +718,6 @@ def fp8_autocast( _QUANTIZE_CONFIG = old_quantize_config -def get_delayed_scaling(): - r""" - Obtain an instance of DelayedScaling which is set via fp8_autocast. - - .. note:: - We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len` - , and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in - recipe.DelayedScaling would be returned as the default values. - - Returns - ------- - delay_scaling : DelayedScaling - an instance of DelayedScaling which is set via fp8_autocast. - """ - amax_compute_algo = ( - "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" - ) - return recipe.DelayedScaling( - margin=int(get_quantize_config().MARGIN), - fp8_format=get_quantize_config().FP8_FORMAT, - amax_history_len=get_quantize_config().AMAX_HISTORY_LEN, - amax_compute_algo=amax_compute_algo, - ) - - def update_collections(new: Collection, original: Collection) -> Collection: r"""Update collections with new values while preserving original structure. diff --git a/transformer_engine/jax/quantize/metadata.py b/transformer_engine/jax/quantize/metadata.py index 637450216..11a349ed7 100644 --- a/transformer_engine/jax/quantize/metadata.py +++ b/transformer_engine/jax/quantize/metadata.py @@ -9,23 +9,29 @@ scale factors and amax history for different tensor types. """ from dataclasses import dataclass -import jax.numpy as jnp __all__ = ["QuantizeMeta", "QuantizeMetaSet"] -@dataclass class QuantizeMeta: """Metadata for quantization parameters. - Attributes: + For Delayed Scaling recipe: scale: The scaling factor for quantization amax_history: History of maximum absolute values + + For NVFP4 recipe with Stochastic Rounding: + sr_rng_state: The state of the stochastic rounding RNG + """ - scale: jnp.ndarray - amax_history: jnp.ndarray + def __init__(self, **kwargs): + self._kwargs = kwargs + + def get_kwargs_dictionary(self): + """Get the metadata as a dictionary.""" + return self._kwargs @dataclass diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 306603bbe..7198014f2 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -19,6 +19,7 @@ from transformer_engine.common import recipe from .scaling_modes import ScalingMode +from .hadamard import apply_rht, should_use_rht from .tensor import ( ScaledTensor, ScaledTensor1x, @@ -28,7 +29,7 @@ ) from .helper import ( get_quantize_config, - get_quantize_config_class, + get_quantize_config_with_recipe, AmaxComputeAlgo, TensorSource, ) @@ -66,6 +67,7 @@ def compute_scale_from_amax( sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) + assert sf.shape == (1,) return sf @@ -155,7 +157,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> """ def quantize( - self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs + self, x, is_rowwise=None, is_colwise=None, dq_dtype=None, flatten_axis=-1, **kwargs ) -> ScaledTensor: """Quantize a tensor using the internal _quantize_func(). @@ -170,6 +172,18 @@ def quantize( A ScaledTensor1x or ScaledTensor2x containing the quantized data """ del kwargs + + is_rowwise = ( + is_rowwise + if is_rowwise is not None + else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) + ) + is_colwise = ( + is_colwise + if is_colwise is not None + else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) + ) + if (is_rowwise and is_colwise) or self.is_2x2x(): rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( @@ -380,6 +394,7 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + # Note, this updating of amax here will only be called once because the "quantize" method impl inherited from CurrentScaleQuantizer only calls _quantize_func once then transposes the result for colwise quantization. So we don't have to worry about update being called twice for 2x2x quantization. self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, @@ -494,7 +509,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> dq_dtype = dq_dtype if dq_dtype is not None else x.dtype x_shape = x.shape scale_shape = self.scaling_mode.get_scale_shape( - x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + x_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) scale_dtype = self.scaling_mode.get_scale_dtype() x = x.reshape( @@ -563,6 +578,221 @@ def _e8m0_to_dtype(self, x, dtype): return new_x.astype(dtype) +@register_pytree_node_class +@dataclass +class NVFP4Quantizer(Quantizer): + """Quantizer implementation using current scaling. + + This quantizer uses current scaling mode with float32 scales + + Attributes: + scaling_mode: Set to NVFP4_1D_SCALING or NVFP4_2D_SCALING + q_layout: Quantization axis + data_layout: Data layout string (default: "NT") + stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled. + """ + + scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE + data_layout: str = "NT" + stochastic_rounding_rng_state: Optional[jnp.ndarray] = None + + def __post_init__(self): + assert ( + self.q_dtype == jnp.float4_e2m1fn + ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn" + assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes" + + def _apply_stochastic_rounding(self, x): + assert ( + self.stochastic_rounding_rng_state is not None + ), "Stochastic rounding RNG state is not initialized" + assert self.stochastic_rounding_rng_state.shape == ( + 4, + ), "Stochastic rounding RNG state must be of shape (4,)" + assert ( + self.stochastic_rounding_rng_state.dtype == jnp.uint32 + ), "Stochastic rounding RNG state must be of dtype uint32" + + # Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s + key_bits = jnp.array( + [ + self.stochastic_rounding_rng_state[0], + self.stochastic_rounding_rng_state[1], + ], + dtype=jnp.uint32, + ) + key = jax.random.wrap_key_data(key_bits) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2]) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3]) + + abs_x = jnp.abs(x) + sign_x = jnp.sign(x) + + floor = ( + (abs_x >= 0.5) * 0.5 + + (abs_x >= 1) * 0.5 + + (abs_x >= 2) + + (abs_x >= 3) + + (abs_x >= 4) + + (abs_x >= 6) * 2 + ) + ceil = ( + 0.5 + + (abs_x > 0.5) * 0.5 + + (abs_x > 1) * 1 + + (abs_x > 2) + + (abs_x > 3) + + (abs_x > 4) * 2 + ) + frac = (abs_x - floor) / (ceil - floor) + + rand = jax.random.uniform(key, abs_x.shape) + return sign_x * jnp.where(frac >= rand, ceil, floor) + + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: + """Quantize function helper for block scaling FP8. + + Args: + x: Input tensor to quantize + is_colwise: Whether to use column-wise quantization + dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor + + Returns: + A ScaledTensor1x containing the quantized data + """ + # TODO(Phuong): use quantize_func from JAX + if flatten_axis < 0: + flatten_axis = x.ndim + flatten_axis + assert ( + 0 <= flatten_axis < x.ndim + ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}" + + should_apply_rht = self.scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise + + global_amax = None + if isinstance(x, NoScaleTensor): + global_amax = ( + x.amax if not should_apply_rht else None + ) # RHT changes the amax so don't use precalculated amax for colwise 1D nvfp4 quantization with RHT + x = x.data + + # Transpose if required + rowwise_flatten_axis = flatten_axis + data_layout = self.data_layout[0] + if is_colwise: + x = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis))) + data_layout = self.data_layout[1] + # convert flatten_axis from N layout to T layout + flatten_axis = x.ndim - flatten_axis + x_shape = x.shape + + if should_use_rht(self.scaling_mode, is_colwise=is_colwise): + # We only apply RHT for 1D colwise nvfp4 + x = apply_rht(x) + + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + scale_shape = self.scaling_mode.get_scale_shape( + x_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=rowwise_flatten_axis, + ) + scale_dtype = self.scaling_mode.get_scale_dtype() + x = x.reshape( + *x_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *x_shape[flatten_axis:-1], + scale_shape[-1], + int(x_shape[-1] / scale_shape[-1]), + ) + + # Dtype max constants + DATA_DTYPE_MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(scale_dtype).max.astype(jnp.float32) + + # Level 1: Current Tensor Scaling + global_amax = ( + global_amax + if global_amax is not None + else jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32) + ) + tensor_scale = DATA_DTYPE_MAX * SCALE_DTYPE_MAX / global_amax + tensor_scale = jnp.minimum( + tensor_scale, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32) + ) + tensor_scale = jnp.where( + tensor_scale == jnp.array(0.0, dtype=jnp.float32), + jnp.array(1.0, dtype=jnp.float32), + tensor_scale, + ) + tensor_scale_inv = 1.0 / tensor_scale + + # Level 2: Block Scaling + block_amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True).astype( + jnp.float32 + ) + block_scale_inv = jnp.divide(block_amax, DATA_DTYPE_MAX) + block_scale_inv = block_scale_inv * tensor_scale + block_scale_inv = jnp.minimum( + block_scale_inv, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32) + ) + block_scale_inv = jnp.clip(block_scale_inv, -SCALE_DTYPE_MAX, SCALE_DTYPE_MAX) + # We cast block_scale_inv to scale_dtype here to account for any rounding during the cast. This will ensure the quantized data incorporates the rounded scale value into its computation so dequantization is accurate. + block_scale_inv = block_scale_inv.astype(scale_dtype) + # Note, with JIT jax removes this intermediate cast leading to slightly incorrect results during DQ and worse convergence to the original tensor during many samples of Q+SR->DQ. So we use reduce_precision to simulate the cast to scale_dtype. + assert scale_dtype == jnp.float8_e4m3fn, "Only float8_e4m3fn is supported for scale_dtype" + block_scale_inv = jax.lax.reduce_precision(block_scale_inv, 4, 3) + block_scale = jnp.minimum( + jnp.divide(1.0, block_scale_inv.astype(jnp.float32) * tensor_scale_inv), + jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32), + ) + + # Apply scaling + scaled_x = x.astype(jnp.float32) * block_scale + if self.stochastic_rounding_rng_state is not None: + scaled_x = self._apply_stochastic_rounding(scaled_x) + clipped_x = jnp.clip(scaled_x, -DATA_DTYPE_MAX, DATA_DTYPE_MAX) + + # Cast to the right dtype + quantized_data = clipped_x.reshape(x_shape).astype(self.q_dtype) + block_scale_inv = block_scale_inv.reshape(scale_shape).astype(scale_dtype) + + # In the 2D scaling mode, the scale shape is 2D but it needs to be broadcasted to 1D for GEMM. + # TODO(Phuong): expose this broadcast_2d_scale_shape_to_1d option to the + # quantizer.quantize() API + broadcasted_1d_scale_shape = self.scaling_mode.get_scale_shape( + x_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=rowwise_flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + + # Broadcast and tile x to match the target shape + def repeat_to_shape(x, target_shape): + x_shape = x.shape + reps = [int(t // s) for s, t in zip(x_shape, target_shape)] + return jnp.tile(x, reps) + + block_scale_inv = repeat_to_shape(block_scale_inv, broadcasted_1d_scale_shape) + + return ScaledTensorFactory.create_1x( + data=quantized_data, + data_layout=data_layout, + is_colwise=is_colwise, + scale_inv=block_scale_inv, + amax=global_amax, + scaling_mode=self.scaling_mode, + dq_dtype=dq_dtype, + flatten_axis=rowwise_flatten_axis, + ) + + @register_pytree_node_class @dataclass class QuantizerSet: @@ -801,6 +1031,8 @@ class QuantizerFactory: ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.NVFP4_1D_SCALING: NVFP4Quantizer, + ScalingMode.NVFP4_2D_SCALING: NVFP4Quantizer, } @staticmethod @@ -826,7 +1058,6 @@ def create( Returns: A single quantizer or tuple of quantizers """ - # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" if n_groups: if n_quantizers != 1: @@ -887,18 +1118,9 @@ def _create_set( if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") - args_x = { - "scale": quantize_meta_set.x.scale, - "amax_history": quantize_meta_set.x.amax_history, - } - args_kernel = { - "scale": quantize_meta_set.kernel.scale, - "amax_history": quantize_meta_set.kernel.amax_history, - } - args_grad = { - "scale": quantize_meta_set.grad.scale, - "amax_history": quantize_meta_set.grad.amax_history, - } + args_x = quantize_meta_set.x.get_kwargs_dictionary() + args_kernel = quantize_meta_set.kernel.get_kwargs_dictionary() + args_grad = quantize_meta_set.grad.get_kwargs_dictionary() else: args_x = args_kernel = args_grad = {} @@ -919,6 +1141,7 @@ def create_set( bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, n_groups: int = None, + # TODO(jberchtold): rename fp8_recipe to quantization_recipe fp8_recipe: Optional[recipe.Recipe] = None, **kwargs, ) -> tuple[Union[tuple[Quantizer], None]]: @@ -946,21 +1169,24 @@ def create_set( ) if fp8_recipe is not None: - quantize_config = get_quantize_config_class(fp8_recipe)() + quantize_config = get_quantize_config_with_recipe(fp8_recipe) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) - elif scaling_mode is not None: - x_scaling_mode = scaling_mode - kernel_scaling_mode = scaling_mode - grad_scaling_mode = scaling_mode + fwd_dtype = quantize_config.FWD_DTYPE + bwd_dtype = quantize_config.BWD_DTYPE else: - x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) - kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) - grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + if scaling_mode is not None: + x_scaling_mode = scaling_mode + kernel_scaling_mode = scaling_mode + grad_scaling_mode = scaling_mode + else: + x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) + kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) - fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE - bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE + fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE + bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: # TODO(Jeremy): check x, kernel, grad separately for 2x if x_scaling_mode.is_1d_block_scaling(): diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index b7828e931..d490e0275 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -100,10 +100,19 @@ def get_scale_dtype(self) -> jnp.dtype: The data type used for scale tensors """ + @abstractmethod + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + @abstractmethod def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, @@ -112,6 +121,7 @@ def get_scale_shape( Args: data_shape: The shape of the tensor being quantized + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -156,13 +166,15 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode @@ -183,12 +195,22 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float32 + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return "NN" + def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = True, ) -> Tuple[int, ...]: """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. @@ -201,7 +223,14 @@ def get_scale_shape( Returns: The shape for scale tensors - (1,) """ - del data_shape, is_colwise, is_padded, flatten_axis + del ( + data_shape, + data_layout, + is_colwise, + is_padded, + flatten_axis, + broadcast_2d_scale_shape_to_1d, + ) return (0,) @lru_cache(maxsize=4) @@ -239,18 +268,20 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ - del flatten_axis + del flatten_axis, broadcast_2d_scale_shape_to_1d input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -270,25 +301,37 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float32 + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return "NT" + def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = True, ) -> Tuple[int, ...]: """Get the shape for scale tensors in delayed scaling. Args: data_shape: The shape of the tensor being scaled + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True. Returns: The shape for scale tensors - (1,) """ - del is_colwise + del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d if np.prod(data_shape) == 0: return (0,) return (1,) @@ -333,6 +376,7 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -340,11 +384,12 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ - del flatten_axis + del flatten_axis, broadcast_2d_scale_shape_to_1d input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) scale_var = BATCHING + unique_var + "_scale_inv" return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) @@ -368,14 +413,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): _block_alignment: Alignment requirements for blocks """ - def __init__(self, block_dims: Tuple[int]): + def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str): """Initialize block scaling mode implementation. Args: block_dims: Dimensions of the scaling blocks + scale_dtype: Data type of the scale tensor + data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. """ self._block_dims = block_dims + self._scale_dtype = scale_dtype self._block_alignment = (128, 4) + self._data_layout = data_layout def get_scale_dtype(self) -> jnp.dtype: """Get the data type for scale tensors in block scaling. @@ -383,7 +432,15 @@ def get_scale_dtype(self) -> jnp.dtype: Returns: The data type used for scale tensors (float8_e8m0fnu) """ - return jnp.float8_e8m0fnu + return self._scale_dtype + + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return self._data_layout def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim): """Remove excess padding from the scale shape and return the shape with respect to the original data shape.""" @@ -411,23 +468,51 @@ def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_ def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = False, ) -> Tuple[int, ...]: """Get the shape for scale tensors in block scaling. Args: data_shape: The shape of the tensor being quantized + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True. Returns: The shape for scale tensors """ + flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape) + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + block_alignment = self._block_alignment if is_padded else (1, 1) + if is_colwise: + assert data_layout == self._data_layout[1], ( + f"Data layout must match colwise layout, received {data_layout} but expected" + f" {self._data_layout[1]}" + ) + else: + assert data_layout == self._data_layout[0], ( + f"Data layout must match rowwise layout, received {data_layout} but expected" + f" {self._data_layout[0]}" + ) + + if is_colwise and self._data_layout[1] == "T": + # TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value + is_colwise = False # now rowwise in T is colwise in N + if flatten_axis < 0: + flatten_axis = len(data_shape) + flatten_axis + # flatten_axis is given wrt N layout, convert to T layout + flatten_axis = len(data_shape) - flatten_axis + if is_colwise: block_y, block_x = self._block_dims alignment_y, alignment_x = block_alignment @@ -435,12 +520,7 @@ def get_scale_shape( block_x, block_y = self._block_dims alignment_x, alignment_y = block_alignment - if flatten_axis < 0: - flatten_axis = len(data_shape) + flatten_axis - assert ( - 0 < flatten_axis < len(data_shape) - ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" - + is_block_2d = block_x > 1 and block_y > 1 assert data_shape[flatten_axis - 1] % block_x == 0, ( f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" f" {flatten_axis - 1}" @@ -449,6 +529,9 @@ def get_scale_shape( data_shape[-1] % block_y == 0 ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" + if broadcast_2d_scale_shape_to_1d and is_block_2d: + block_x = 1 + flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) @@ -575,6 +658,7 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + broadcast_2d_scale_shape_to_1d, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -582,30 +666,41 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ + # TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed input_rank = len(input_shape) input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] flatten_axis = (flatten_axis + input_rank) % input_rank - # This implementation needs to be updated for different block dims. - assert self._block_dims == (1, 32) + assert ( + self._block_dims[1] != 1 + ), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}" + # For 2D block scaling, only support when with broadcast_2d_scale_shape_to_1d + if self._block_dims[0] != 1: + assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, ( + f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d}," + f" _block_dims={self._block_dims}" + ) + + block_size_1d = self._block_dims[1] # We have to use two different factors in the two CompoundFactors because of Shardy # verifier requirements, even though they are the same. blocksizes = {} colwise_var = f"{unique_var}_None" rowwise_var = f"{unique_var}_None" - if not input_shape[-1] == 32: + if not input_shape[-1] == block_size_1d: rowwise_var = input_spec[-1] + "_compound" input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") - blocksizes["blocksize_x"] = 32 - if not input_shape[flatten_axis - 1] == 32: + blocksizes["blocksize_x"] = block_size_1d + if not input_shape[flatten_axis - 1] == block_size_1d: colwise_var = input_spec[flatten_axis - 1] + "_compound" input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") - blocksizes["blocksize_y"] = 32 + blocksizes["blocksize_y"] = block_size_1d # The rowwise and colwise scale tensors should be sharded the same way as the input. # However, we need to adjust the dimensions where the block scaling factor applies. @@ -632,6 +727,8 @@ class ScalingMode(Enum): - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales + - NVFP4_1D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales + - NVFP4_2D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales - NO_SCALING: No scaling applied """ @@ -639,6 +736,8 @@ class ScalingMode(Enum): DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING + NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING + NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. @@ -662,40 +761,79 @@ def get_scale_dtype(self): """ return self._get_impl().get_scale_dtype() - def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]: + def get_scale_shape_2x( + self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False + ) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor is_padded: Whether to use padded shapes flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ + data_layout = self._get_impl().get_data_layout() + rowwise_layout = data_layout[0] + assert ( + rowwise_layout == "N" + ), f"For rowwise layout only 'N' is supported, received {rowwise_layout}" + colwise_layout = data_layout[1] + rowwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis + data_shape, + data_layout=rowwise_layout, + is_colwise=False, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, ) + + colwise_data_shape = data_shape + if colwise_layout == "T": + colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis] colwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis + colwise_data_shape, + data_layout=colwise_layout, + is_colwise=True, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, ) return (rowwise_scale_shape, colwise_scale_shape) def get_scale_shape( - self, data_shape, is_colwise, is_padded=True, flatten_axis=-1 + self, + data_shape, + data_layout="N", + is_colwise=False, + is_padded=True, + flatten_axis=-1, + broadcast_2d_scale_shape_to_1d=False, ) -> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Shape of the data tensor + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: The shape for scale tensors """ - return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + return self._get_impl().get_scale_shape( + data_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, + ) def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: """Get the quantize layout for the tensor usage. @@ -713,6 +851,7 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis=-1, + broadcast_2d_scale_shape_to_1d=False, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. @@ -720,11 +859,14 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules(input_shape, unique_var, flatten_axis) + return self._get_impl().get_shardy_sharding_rules( + input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d + ) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 @@ -798,8 +940,64 @@ def is_1d_block_scaling(self) -> bool: Returns: True if the scaling mode is 1D block scaling, False otherwise """ + # Both 1D and 2D NVFP4 scaling are treated as 1D block scaling since the 2D scales are broadcast to 1D because it is required for the GEMM. + return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling + + @property + def is_block_scaling(self) -> bool: + """Check if this scaling mode is block scaling. + + Returns: + True if the scaling mode is block scaling, False otherwise + """ + # Currently we only have 1D block scaling modes + return self.is_1d_block_scaling() + + def get_compatible_q_dtypes(self) -> set[jnp.dtype]: + """Returns a set of compatible quantized data types for this scaling mode. + + Returns: + A set of compatible quantized data types + """ + if self in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ScalingMode.MXFP8_1D_SCALING, + ): + return {jnp.float8_e5m2, jnp.float8_e4m3fn} + if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING): + return {jnp.float4_e2m1fn} + if self == ScalingMode.NO_SCALING: + return {jnp.float16, jnp.bfloat16, jnp.float32} + raise ValueError(f"Invalid scaling mode: {self}") + + @property + def is_nvfp4_scaling(self) -> bool: + """Check if this scaling mode is NVFP4 scaling. + + Returns: + True if the scaling mode is NVFP4 scaling, False otherwise + """ + return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING) + + @property + def is_mxfp8_scaling(self) -> bool: + """Check if this scaling mode is NVFP4 scaling. + + Returns: + True if the scaling mode is NVFP4 scaling, False otherwise + """ return self == ScalingMode.MXFP8_1D_SCALING + @property + def is_colwise_transposed(self) -> bool: + """Check if this scaling mode uses transposed layout for column-wise scaling. + + Returns: + True if the scaling mode uses transposed layout for column-wise scaling, False otherwise + """ + return self.is_tensor_scaling() or self.is_nvfp4_scaling + def __eq__(self, other): """Compare this scaling mode with another. @@ -836,9 +1034,20 @@ def tree_unflatten(cls, aux_data, _children): SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = { + ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(), - ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), - # WAR + ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl( + block_dims=(1, 32), + scale_dtype=jnp.float8_e8m0fnu, + data_layout="NN", + ), ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), - ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), + ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl( + block_dims=(1, 16), + scale_dtype=jnp.float8_e4m3fn, + data_layout="NT", + ), + ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl( + block_dims=(16, 16), scale_dtype=jnp.float8_e4m3fn, data_layout="NT" + ), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index dbbac4abc..2d2d78190 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -201,13 +201,32 @@ def __post_init__(self): else: unpadded_scale_shape = self.scaling_mode.get_scale_shape( self.data.shape, + data_layout=self.data_layout, is_colwise=self.is_colwise, is_padded=False, - flatten_axis=self.flatten_axis, + # expect the flatten_axis wrt the N layout + flatten_axis=( + self.flatten_axis + if self.data_layout == "N" + else self.data.ndim - self.flatten_axis + ), ) - assert self.scale_inv.shape == unpadded_scale_shape, ( - "Unpadded inverse scale factor has wrong shape, expected" - f" {unpadded_scale_shape} but got {self.scale_inv.shape}." + unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( + self.data.shape, + data_layout=self.data_layout, + is_colwise=self.is_colwise, + is_padded=False, + # expect the flatten_axis wrt the N layout + flatten_axis=( + self.flatten_axis + if self.data_layout == "N" + else self.data.ndim - self.flatten_axis + ), + broadcast_2d_scale_shape_to_1d=True, + ) + assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or" + f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." ) def tree_flatten(self): @@ -583,6 +602,7 @@ def create_2x( colwise_data, colwise_scale_inv, amax=None, + colwise_amax=None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=jnp.bfloat16, data_layout="NN", @@ -612,6 +632,8 @@ def create_2x( """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) + if colwise_amax is None: + colwise_amax = amax assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" rowwise_tensor = ScaledTensorFactory.create_1x( @@ -630,10 +652,10 @@ def create_2x( colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, - amax, + colwise_amax, scaling_mode, dq_dtype, - is_colwise=True, + is_colwise=True, # TODO(Phuong): set this correctly data_layout=data_layout[1], flatten_axis=flatten_axis, group_sizes=group_sizes, @@ -649,6 +671,7 @@ def create( colwise_data: jnp.ndarray, colwise_scale_inv: jnp.ndarray, amax=None, + colwise_amax=None, scaling_mode: ScalingMode = ScalingMode.NO_SCALING, dq_dtype: jnp.dtype = jnp.bfloat16, data_layout: str = "NN", @@ -684,6 +707,7 @@ def create( colwise_data, colwise_scale_inv, amax, + colwise_amax, scaling_mode, dq_dtype, data_layout=data_layout, @@ -698,7 +722,7 @@ def create( return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, - amax, + colwise_amax if colwise_amax is not None else amax, scaling_mode, dq_dtype, is_colwise=is_colwise, From dd9433e7ad28c12f27da9770be54c9c584e85fa0 Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Thu, 9 Oct 2025 17:30:35 -0600 Subject: [PATCH 046/141] Don't pickle an empty dict in LayerNorm and pt base modules (#2253) Don't pickle an empty dict in LayerNorm and BasicOperation layers Signed-off-by: Peter St. John Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/op.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 103ebf241..095e3e89e 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -595,6 +595,9 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: extra[key] = val state[mode]["extra_fp8_variables"] = extra + if not state: + return torch.empty(0, dtype=torch.uint8) + # Serialize state into byte tensor torch.cuda.synchronize() state_serialized = bytearray(pickle.dumps(state)) From 7ad130efd52c3aa4a386d25f1d42b28d5aa20090 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Mon, 13 Oct 2025 01:10:38 -0700 Subject: [PATCH 047/141] Offloading support for multiple attention layouts (#2024) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added multi-layout support for attention Signed-off-by: Selvaraj Anandaraj * Comment/cleanup Signed-off-by: Selvaraj Anandaraj * Bug fix on import time Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Pawel Gadzinski Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Pawel Gadzinski Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- .../dot_product_attention/backends.py | 27 +++++++++++++++++-- transformer_engine/pytorch/cpu_offload.py | 13 +++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 0ddb261d2..d75481ad9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1258,7 +1258,6 @@ def forward( else: tensor_list = [q, k, v, out] - qkv_layout = "sbhd_sbhd_sbhd" mark_activation_offload(*tensor_list) mark_activation_offload(*aux_ctx_tensors) @@ -1293,7 +1292,31 @@ def forward( ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout + + from transformer_engine.pytorch.cpu_offload import ( + CPUOffloadedLayer, + ) + + # If interleaved tensor is offloaded, reloaded tensor will be + # non-interleaved, so we need to modify the QKV layout + # for backward + if CPUOffloadedLayer and CPUOffloadEnabled: + reload_layout = "" + split_list = qkv_layout.split("_") + for split in split_list: + temp_layout = "" + rep_count = 1 + for s in split: + if s.isalpha(): + temp_layout = temp_layout + s + else: + rep_count = int(s) + for _ in range(rep_count): + reload_layout = reload_layout + temp_layout + "_" + ctx.qkv_layout = reload_layout[:-1] + else: + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = attn_bias_type ctx.attn_mask_type = attn_mask_type ctx.softmax_type = softmax_type diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 9378774ea..648b21eb4 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -16,6 +16,7 @@ __all__ = ["get_cpu_offload_context"] CPUOffloadEnabled = False +CPUOffloadedLayer = False def mark_activation_offload(*tensors): @@ -353,6 +354,7 @@ def __init__( self.h2d_stream = torch.cuda.Stream() def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + global CPUOffloadedLayer torch_stray_tensor = isinstance( tensor, @@ -408,6 +410,11 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: tensor.clear() else: self.tensor_tag_to_buf[tensor_tag] = t + + # Needed to differentiate non offloaded layer's attention + # QKV layout of attention of non-offloaded layer needs + # to be modified while reloading + CPUOffloadedLayer = True else: tensor_tag = (-1, self.torch_tensor_count) self.torch_tensor_count += 1 @@ -417,6 +424,8 @@ def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: def tensor_pop(self, tensor_tag, **kwargs): """Tensor pop.""" + global CPUOffloadedLayer + assert tensor_tag in self.tensor_tag_to_state tensor = self.tensor_tag_to_state.pop(tensor_tag) @@ -480,6 +489,7 @@ def bulk_offload_group(self, group_to_offload): def synchronize_on_group_commit_forward(self, current_group): """Synchronize on group commit forward.""" + global CPUOffloadedLayer # For the first group, kickstart the offload after we have # the first compute completion @@ -528,6 +538,9 @@ def synchronize_on_group_commit_forward(self, current_group): # Increment the offload group count to keep track self.offloaded_group_count += 1 + if current_group == (self.num_offload_group - 1): + CPUOffloadedLayer = False + if not self.double_buffer_created: # Creating second copy of double buffer for tensors that are offloaded if current_group == (self.num_layers - 1): From 8eec2004f1abf6bc6083fc13820c117c39be80ad Mon Sep 17 00:00:00 2001 From: "Peter St. John" Date: Mon, 13 Oct 2025 11:23:18 -0600 Subject: [PATCH 048/141] Disable torch autocast context in rope forward pass (#2240) Signed-off-by: Peter St. John Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fused_rope.py | 16 +++++++++ transformer_engine/pytorch/attention/rope.py | 38 +++++++++++--------- 2 files changed, 37 insertions(+), 17 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 62d80b552..aaf2eca2d 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -373,3 +373,19 @@ def test_fused_qkv_rope( if not isinstance(start_positions, torch.Tensor): torch.testing.assert_close(grad_fused, grad_unfused) + + +def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast(): + rope_layer = RotaryPositionEmbedding(128) + + rope_embeddings_no_autocast = rope_layer(max_seq_len=1024) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + rope_embeddings_autocast = rope_layer(max_seq_len=1024) + + torch.testing.assert_close( + rope_embeddings_no_autocast.to(dtype=torch.bfloat16), + rope_embeddings_autocast.to(dtype=torch.bfloat16), + atol=1e-8, + rtol=1e-8, + ) diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 139381f2d..cc23d65a3 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -66,6 +66,9 @@ def forward(self, max_seq_len: int, offset: int = 0): """ Create rotary position embedding frequencies. + This function is particularly sensitive to the use of mixed precision, so we disable the + autocast context if it is enabled. + Parameters ---------- max_seq_len: int @@ -73,26 +76,27 @@ def forward(self, max_seq_len: int, offset: int = 0): offset: int, default = 0 Fixed offset for frequencies. """ - seq = ( - torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - + offset - ) + with torch.autocast(enabled=False, device_type="cuda"): + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) - if ( - self.pretrained_max_position_embeddings is not None - and self.seq_len_interpolation_factor is not None - ): if ( - max_seq_len - > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor + self.pretrained_max_position_embeddings is not None + and self.seq_len_interpolation_factor is not None ): - # dynamic linear scaling (length > position we have learned) - seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) - else: - # fixed linear scaling - seq *= 1 / self.seq_len_interpolation_factor - - freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) + if ( + max_seq_len + > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor + ): + # dynamic linear scaling (length > position we have learned) + seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) + else: + # fixed linear scaling + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size if not self.interleaved: From 8c364b4d5915290c29e337e9157426a1a15bf658 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Mon, 13 Oct 2025 13:00:11 -0700 Subject: [PATCH 049/141] [Common][JAX] Improve error message for cublas fp8 gemm with incorrect shape (#2261) * Improve error message for cublas fp8 gemm with incorrect shape Signed-off-by: Jeremy Berchtold * lint Signed-off-by: Jeremy Berchtold * Removed unnecessary non-contracting size check Signed-off-by: Jeremy Berchtold * rename inner dim -> leading dim Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- .../common/gemm/cublaslt_gemm.cu | 14 +++++- transformer_engine/jax/cpp_extensions/gemm.py | 49 ++++++++++++++++++- 2 files changed, 60 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index a4810881c..84a1b735a 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -141,6 +141,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } } + + if (is_fp8_dtype(ret.Atype)) { + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK(ret.lda % 16 == 0, + "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + } } else if (nvfp4) { // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. @@ -187,7 +193,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK((ret.lda % 16) == 0, - "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. NVTE_CHECK((m % 8) == 0, @@ -216,6 +222,12 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } } + + if (is_fp8_dtype(ret.Atype)) { + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK(ret.ldb % 16 == 0, + "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + } } else if (nvfp4) { if (is_B_transposed) { NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode), diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b72161f1a..b37c4bd84 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -360,6 +360,28 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) +def get_lhs_axis_boundary(lhs_cdims, is_transposed): + """Get the axis boundary for the LHS operand.""" + return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims) + + +def get_rhs_axis_boundary(rhs_cdims, is_transposed): + """Get the axis boundary for the RHS operand.""" + return min(rhs_cdims) if is_transposed else max(rhs_cdims) + 1 + + +def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): + """Assert that the given tensor shape and layout meet the requirements for cuBLAS GEMM.""" + if scaling_mode != ScalingMode.NO_SCALING: + # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + alignment = 32 if scaling_mode.is_nvfp4_scaling else 16 + + assert contracting_size % alignment == 0, ( + f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" + f" {alignment} when using quantized inputs. Got contracting_size={contracting_size}" + ) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -452,6 +474,29 @@ def _dims_are_consecutive(dims): f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" ) + lhs_axis_boundary = get_lhs_axis_boundary(lhs_contracting_dims, lhs_is_transposed) + lhs_contracting_size = ( + reduce(operator.mul, lhs.shape[lhs_axis_boundary:]) + if lhs_is_transposed + else reduce(operator.mul, lhs.shape[:lhs_axis_boundary]) + ) + assert_cublas_requirements( + scaling_mode, + lhs_contracting_size, + "LHS", + ) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_contracting_dims, rhs_is_transposed) + rhs_contracting_size = ( + reduce(operator.mul, rhs.shape[:rhs_axis_boundary]) + if rhs_is_transposed + else reduce(operator.mul, rhs.shape[rhs_axis_boundary:]) + ) + assert_cublas_requirements( + scaling_mode, + rhs_contracting_size, + "RHS", + ) + # Determine output shape and dtype assert ( dtypes.canonicalize_dtype(out_dtype).itemsize > 1 @@ -563,8 +608,8 @@ def lowering( args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { "scaling_mode": int(scaling_mode.value), - "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_axis_boundary": get_lhs_axis_boundary(lhs_cdims, lhs_transposed), + "rhs_axis_boundary": get_rhs_axis_boundary(rhs_cdims, rhs_transposed), "lhs_transposed": lhs_transposed, "rhs_transposed": rhs_transposed, "fuse_bias": fuse_bias, From 76e1af33401f9631851fc9b8d8bd35ff4b959da5 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Mon, 13 Oct 2025 14:25:19 -0700 Subject: [PATCH 050/141] [JAX] Add assertion message to amax -> scale computation (#2263) assertion check Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/quantize/quantizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 7198014f2..7bc08f834 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -67,7 +67,7 @@ def compute_scale_from_amax( sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) - assert sf.shape == (1,) + assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" return sf From a3b749b18609f1715fbc97dc7d2639a2c2510248 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Mon, 13 Oct 2025 16:27:49 -0600 Subject: [PATCH 051/141] FSDP grad fusion support (#2191) * FSDP grad fusion support Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Re-factored grad overwriting usage Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/ops/basic/basic_linear.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/ops/fused/backward_linear_add.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/ops/fused/backward_linear_scale.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Selvaraj Anandaraj * Update transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Selvaraj Anandaraj * Modified API usage, added arg details Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../pytorch/module/grouped_linear.py | 10 ++++++++-- .../pytorch/module/layernorm_linear.py | 10 ++++++++-- .../pytorch/module/layernorm_mlp.py | 16 +++++++++++++--- transformer_engine/pytorch/module/linear.py | 10 ++++++++-- .../pytorch/ops/basic/basic_linear.py | 5 ++++- .../pytorch/ops/fused/backward_linear_add.py | 1 + .../pytorch/ops/fused/backward_linear_scale.py | 1 + .../ops/fused/userbuffers_backward_linear.py | 1 + 8 files changed, 44 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b3adfb7db..ec05f684b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -402,7 +402,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_bias=ctx.use_bias if grad_biases[0] is None else None, bias=biases, use_split_accumulator=wgrad_gemm_use_split_accumulator, - accumulate=accumulate_wgrad_into_param_main_grad, + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), ) # WGRAD if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): @@ -519,7 +523,9 @@ class GroupedLinear(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index e1c0eab2d..0559ae7ce 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -849,7 +849,11 @@ def backward( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ctx.grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, "bias": (bias if (grad_bias is None and not ctx.fp8) else None), @@ -1125,7 +1129,9 @@ class LayerNormLinear(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index d680a9f8f..8ef19d052 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -948,7 +948,11 @@ def backward( else ctx.activation_dtype ), "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(fc1_weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, @@ -1189,7 +1193,11 @@ def fc2_wgrad_gemm( else ctx.activation_dtype ), "quantization_params": ctx.fc1_grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(fc2_weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, @@ -1484,7 +1492,9 @@ class LayerNormMLP(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 02872439a..67124c157 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -843,7 +843,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ctx.grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, "bias": (bias if (grad_bias is None and not ctx.fp8) else None), @@ -1061,7 +1065,9 @@ class Linear(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index cb2119296..b15d840d6 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -80,7 +80,9 @@ class BasicLinear(BasicOperation): autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be meaningful. This is primarily intented to integrate with - Megatron-LM. + Megatron-LM. This argument along with weight tensor having + attribute 'overwrite_main_grad' set to True will overwrite + `main_grad` instead of accumulating. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -1019,6 +1021,7 @@ def op_backward( weight_param = self.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 845ba262a..a86745a68 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -59,6 +59,7 @@ def fuser_backward( weight_param = linear_op.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index a9595d516..832e51de8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -60,6 +60,7 @@ def fuser_backward( weight_param = linear_op.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 1ecdba625..d95b2298f 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -523,6 +523,7 @@ def fuser_backward( weight_param = linear_op.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " From 5ec0f33b7ed8d6e2bd2e2a1be01f93c9d2fd7422 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Mon, 13 Oct 2025 18:36:00 -0700 Subject: [PATCH 052/141] [JAX] Fix test path for fp8 grouped gemm ag (#2262) Fix test path so that it gets triggered Signed-off-by: Kshitij Lakhani --- qa/L1_jax_distributed_unittest/test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 8ecc5a917..270f0df15 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -9,4 +9,4 @@ set -xe mkdir -p "$XML_LOG_DIR" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* -SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh +SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh From dfacd9f76bcabcdd53cb30a17679ad6032cf54f4 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Tue, 14 Oct 2025 10:11:01 +0200 Subject: [PATCH 053/141] [PyTorch] Use Quantization API for reference NVFP4 recipe (#2259) * Fix update_quantized in ref nvfp4 quantizer Signed-off-by: Evgeny * Subclass quantization API Signed-off-by: Evgeny * Use recipe.Custom and quantizer factories for reference NVFP4 Signed-off-by: Evgeny * Linter fix Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Evgeny Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/distributed/run_numerics_exact.py | 75 +++++-- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 2 +- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 86 ++++--- .../nvfp4/test_nvfp4_quantize_exact.py | 2 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 2 +- .../dot_product_attention.py | 2 + .../pytorch/experimental/__init__.py | 5 - .../pytorch/experimental/config.py | 201 ----------------- .../pytorch/experimental/gemm.py | 24 +- .../pytorch/experimental/quantization.py | 174 -------------- ...icroblock_ref.py => quantization_nvfp4.py} | 212 ++++++++++++------ transformer_engine/pytorch/module/_common.py | 28 --- .../pytorch/module/layernorm_linear.py | 8 +- transformer_engine/pytorch/module/linear.py | 8 +- 14 files changed, 286 insertions(+), 543 deletions(-) delete mode 100644 transformer_engine/pytorch/experimental/config.py rename transformer_engine/pytorch/experimental/{quantization_microblock_ref.py => quantization_nvfp4.py} (83%) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index b1722b79a..40be8e1f0 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -21,9 +21,12 @@ Format, Recipe, QParams, + CustomRecipe, ) from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.experimental import quantization_nvfp4 +from transformer_engine.pytorch.experimental import utils from run_layer_with_overlap import _compare_tensors @@ -48,6 +51,52 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def get_nvfp4_quantizer_factory(): + """ + Create a quantizer factory for NVFP4 reference implementation. + + This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization + enabled. + + Returns: + A factory function that takes a role string and returns a quantizer instance + """ + + def factory(role): + if role == "linear_input": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, # RHT enabled for input + ) + elif role == "linear_weight": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), # 2D quantization for weight + pow_2_scales=False, + with_rht=False, + ) + elif role == "linear_output": + # Output quantization not used + return None + elif role == "linear_grad_output": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, # RHT enabled for grad_output + ) + elif role == "linear_grad_input": + # Grad input quantization not used + return None + else: + # For any other roles, return None + return None + + return factory + + # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "nvfp4": @@ -55,16 +104,12 @@ def quantization_recipe() -> Recipe: raise ValueError(f"Unsupported quantization: {QUANTIZATION}") -def setup_environment_for_reference(): +def quantization_reference_recipe() -> Recipe: + """Create reference recipe using CustomRecipe with NVFP4 quantizer factory.""" if QUANTIZATION == "nvfp4": - os.environ["QAT_PARAMS"] = "9003" - else: - raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}") - - -def cleanup_environment(): - if "QAT_PARAMS" in os.environ: - del os.environ["QAT_PARAMS"] + nvfp4_ref_factory = get_nvfp4_quantizer_factory() + return CustomRecipe(qfactory=nvfp4_ref_factory) + raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}") def main(argv=None, namespace=None): @@ -478,8 +523,8 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): ) # run the reference - setup_environment_for_reference() - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + reference_recipe = quantization_reference_recipe() + with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe): y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear( x, w, @@ -494,8 +539,6 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): run_num_steps=run_num_steps, enable_weight_cache=enable_weight_cache, ) - # Clean up env - cleanup_environment() # compare results, zero tolerance if WORLD_RANK == 0: @@ -673,8 +716,8 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs ) # run the reference - setup_environment_for_reference() - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + reference_recipe = quantization_reference_recipe() + with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe): y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = ( TestDistributedLayerNormLinearBase.run_layernorm_linear( x, @@ -690,8 +733,6 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs enable_weight_cache=False, ) ) - # Clean up env - cleanup_environment() # compare results, zero tolerance if WORLD_RANK == 0: diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index a9e73aaf9..42837fb40 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -9,7 +9,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer -from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.experimental import utils diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index ae9975839..1d1467640 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -2,13 +2,14 @@ # # See LICENSE for license information. -import os import pytest import torch import transformer_engine as te from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.distributed import fp8_autocast from transformer_engine.common import recipe +from transformer_engine.pytorch.experimental import quantization_nvfp4 +from transformer_engine.pytorch.experimental import utils recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() @@ -65,20 +66,54 @@ def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = Fa return GetRecipes.nvfp4_vanilla() -def setup_environment_for_reference(with_rht: bool = False, with_2d_quantization: bool = False): - if with_rht and with_2d_quantization: - os.environ["QAT_PARAMS"] = "9003" - elif with_rht: - os.environ["QAT_PARAMS"] = "960109" - elif with_2d_quantization: - os.environ["QAT_PARAMS"] = "9002" - else: - os.environ["QAT_PARAMS"] = "6010" +def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bool = False): + """ + Create a quantizer factory for NVFP4 reference implementation. + + This factory returns NVFP4QuantizerRef instances based on the role and configuration. + Used with CustomRecipe to create reference quantizers. + + Args: + with_rht: Whether to enable random Hadamard transform + with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights) + + Returns: + A factory function that takes a role string and returns a quantizer instance + """ + def factory(role): + if role == "linear_input": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=with_rht, + ) + elif role == "linear_weight": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), + pow_2_scales=False, + with_rht=False, + ) + elif role == "linear_output": + # Output quantization not used + return None + elif role == "linear_grad_output": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=with_rht, + ) + elif role == "linear_grad_input": + # Grad input quantization not used + return None + else: + # For any other roles, return None + return None -def cleanup_environment(): - if "QAT_PARAMS" in os.environ: - del os.environ["QAT_PARAMS"] + return factory def reset_rng_states(): @@ -113,7 +148,6 @@ def check_nvfp4_module_versus_reference( seq_len = 128 # Create both modules with identical initialization - cleanup_environment() reset_rng_states() # Create native module @@ -138,7 +172,6 @@ def check_nvfp4_module_versus_reference( raise ValueError(f"Unsupported module class: {module_class}") # Create reference module with same weights - setup_environment_for_reference(with_rht, with_2d_quantization) reset_rng_states() # Create reference module @@ -174,7 +207,10 @@ def check_nvfp4_module_versus_reference( if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + # Create recipes for native and reference implementations nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization) + nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory) # Training loop comparison native_outputs = [] @@ -196,17 +232,13 @@ def check_nvfp4_module_versus_reference( grad_output = grad_output_val.clone().detach() # Native forward/backward - cleanup_environment() with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): # enable weight cache by giving is_first_microbatch y_native = native_module(x_native, is_first_microbatch=(step == 0)) y_native.backward(grad_output) # Reference forward/backward - setup_environment_for_reference(with_rht, with_2d_quantization) - with fp8_autocast( - enabled=True, fp8_recipe=nvfp4_recipe - ): # Exact recipe does not play a role here + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe): y_ref = ref_module(x_ref) y_ref.backward(grad_output) @@ -295,9 +327,6 @@ def check_nvfp4_module_versus_reference( msg=f"Bias gradient mismatch at step {step}", ) - # Clean up - cleanup_environment() - @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( @@ -362,7 +391,6 @@ def check_nvfp4_layernorm_linear_versus_reference( seq_len = 128 # Create both modules with identical initialization - cleanup_environment() reset_rng_states() # Native module @@ -377,7 +405,6 @@ def check_nvfp4_layernorm_linear_versus_reference( ) # Reference module - setup_environment_for_reference(with_rht, with_2d_quantization) reset_rng_states() ref_module = te.pytorch.LayerNormLinear( in_features=in_features, @@ -405,7 +432,10 @@ def check_nvfp4_layernorm_linear_versus_reference( if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None: ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + # Create recipes for native and reference implementations nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization) + nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory) native_outputs = [] ref_outputs = [] @@ -426,14 +456,12 @@ def check_nvfp4_layernorm_linear_versus_reference( grad_output = grad_output_val.clone().detach() # Native forward/backward - cleanup_environment() with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0)) y_native.backward(grad_output) # Reference forward/backward - setup_environment_for_reference(with_rht, with_2d_quantization) - with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe): y_ref, ln_out_ref = ref_module(x_ref) y_ref.backward(grad_output) @@ -515,8 +543,6 @@ def check_nvfp4_layernorm_linear_versus_reference( msg=f"Bias gradient mismatch at step {step}", ) - cleanup_environment() - @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index dc3c4a4e9..cdcb2df1d 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -12,7 +12,7 @@ from transformer_engine.pytorch.tensor.nvfp4_tensor import ( NVFP4Quantizer, ) -from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index bb542456e..494fa63c0 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -20,7 +20,7 @@ from transformer_engine.pytorch.tensor.nvfp4_tensor import ( NVFP4Quantizer, ) -from transformer_engine.pytorch.experimental.quantization_microblock_ref import NVFP4QuantizerRef +from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.experimental import utils from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index df96067d6..0a8802fb0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -546,6 +546,8 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in fp8_autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.custom(): + return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. diff --git a/transformer_engine/pytorch/experimental/__init__.py b/transformer_engine/pytorch/experimental/__init__.py index 11658f636..6e859ba5d 100644 --- a/transformer_engine/pytorch/experimental/__init__.py +++ b/transformer_engine/pytorch/experimental/__init__.py @@ -3,8 +3,3 @@ # See LICENSE for license information. """Experimental features and APIs.""" - -from .config import set_qlinear_params, get_experimental_quantizers - - -__all__ = ["set_qlinear_params", "get_experimental_quantizers"] diff --git a/transformer_engine/pytorch/experimental/config.py b/transformer_engine/pytorch/experimental/config.py deleted file mode 100644 index fec6bc938..000000000 --- a/transformer_engine/pytorch/experimental/config.py +++ /dev/null @@ -1,201 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -"""Config API for experimental middleware between Transformer Engine and Kitchen.""" - -import dataclasses -import enum -import os -from typing import Optional - -from transformer_engine.pytorch.experimental import utils -from transformer_engine.pytorch.experimental import quantization -from transformer_engine.pytorch.experimental import quantization_microblock_ref -from transformer_engine.pytorch.experimental.quantization import MMParams - - -@dataclasses.dataclass() -class QLinearParams: - """Quantization parameters of linear layer. - - Contains ready-to-use quantizers for input (x), weight (w), and gradient (g) tensors. - """ - - x_quantizer: Optional[quantization.ExperimentalQuantizer] = None - w_quantizer: Optional[quantization.ExperimentalQuantizer] = None - g_quantizer: Optional[quantization.ExperimentalQuantizer] = None - - mm_fprop: Optional[MMParams] = None - mm_dgrad: Optional[MMParams] = None - mm_wgrad: Optional[MMParams] = None - - -@enum.unique -class QuantizeRecipe(enum.Enum): - """Pre-defined quantization recipes for linear layers.""" - - NON_QUANTIZE = "non_quantize" - NVFP4_REF = "nvfp4_ref" - NVFP4_REF_RHT_ONLY = "nvfp4_ref_rht_only" - NVFP4_REF_2D_QUANTIZATION_ONLY = "nvfp4_ref_2d_quantization_only" - NVFP4_REF_RHT_AND_2D_QUANTIZATION = "nvfp4_ref_rht_and_2d_quantization" - - -def get_qlinear_params_from_predefined( - recipe: QuantizeRecipe, -) -> Optional[QLinearParams]: - """Get quantization parameters for linear layer based on recipe.""" - if recipe == QuantizeRecipe.NON_QUANTIZE: - return None - if recipe == QuantizeRecipe.NVFP4_REF: - return QLinearParams( - x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - ), - w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - ), - g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - ), - ) - if recipe == QuantizeRecipe.NVFP4_REF_RHT_ONLY: - return QLinearParams( - x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ), - w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=False, - ), - g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ), - ) - if recipe == QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY: - return QLinearParams( - x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=False, - ), - w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(16, 16), - pow_2_scales=False, - with_rht=False, - ), - g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=False, - ), - ) - if recipe == QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION: - return QLinearParams( - x_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ), - w_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(16, 16), - pow_2_scales=False, - with_rht=False, - ), - g_quantizer=quantization_microblock_ref.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ), - ) - raise ValueError(f"Unsupported quantize recipe: {recipe}") - - -def get_qlinear_params_from_qat_params(qat_params_idx: int) -> Optional[QLinearParams]: - """Load quantization options from Kitchen to Transformer Engine. - - TODO(etsykunov): Confirm docstring is correct. - """ - assert qat_params_idx > 0, "QAT_PARAMS is not set." - - if qat_params_idx == 6010: - return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF) - if qat_params_idx == 960109: - return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_ONLY) - if qat_params_idx == 9002: - return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_2D_QUANTIZATION_ONLY) - if qat_params_idx == 9003: - return get_qlinear_params_from_predefined(QuantizeRecipe.NVFP4_REF_RHT_AND_2D_QUANTIZATION) - raise ValueError(f"Unsupported QAT params index: {qat_params_idx}") - - -def set_qlinear_params( - qlinear_params: Optional[QLinearParams] = None, - layer_number: Optional[int] = None, - layer_name: Optional[str] = None, -) -> Optional[QLinearParams]: - """Set quantization parameters based on configuration. - - Args: - qlinear_params: Quantization parameters. If None, loaded from environment. - layer_number: The numerical index of this layer in the model structure. - layer_name: The name for this layer. - - Returns: - QLinearParams: The finalized quantization parameters for this layer. - """ - if qlinear_params is None: - qat_params_idx = int(os.getenv("QAT_PARAMS", "0")) - if qat_params_idx == 0: - return None - return get_qlinear_params_from_qat_params(qat_params_idx) - - # Apply layer-specific overrides - if layer_number is not None: - raise NotImplementedError("Layer-specific overrides are not supported yet.") - if layer_name is not None: - raise NotImplementedError("Layer-specific overrides are not supported yet.") - - return qlinear_params - - -def get_experimental_quantizers(fp8: bool, qlinear_params: QLinearParams): - """Replacement of _get_quantizers() in TE modules.""" - if not fp8: - raise ValueError("FP8 is required to be enabled for experimental quantization.") - input_quantizer = qlinear_params.x_quantizer - weight_quantizer = qlinear_params.w_quantizer - output_quantizer = None - grad_input_quantizer = None - grad_weight_quantizer = None - grad_output_quantizer = qlinear_params.g_quantizer - - return ( - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, - ) diff --git a/transformer_engine/pytorch/experimental/gemm.py b/transformer_engine/pytorch/experimental/gemm.py index d743b577b..0bd740d85 100644 --- a/transformer_engine/pytorch/experimental/gemm.py +++ b/transformer_engine/pytorch/experimental/gemm.py @@ -11,14 +11,14 @@ from transformer_engine.pytorch.experimental.quantization import ( MMParams, GEMMType, - ExperimentalQuantizedTensor, ) -from transformer_engine.pytorch.tensor.quantized_tensor import Quantizer +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.tensor.utils import is_experimental def experimental_gemm( - A: ExperimentalQuantizedTensor, - B: ExperimentalQuantizedTensor, + A: QuantizedTensorStorage, + B: QuantizedTensorStorage, workspace: torch.Tensor, # pylint: disable=unused-argument out_dtype: Optional[torch.dtype] = None, quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument @@ -32,9 +32,7 @@ def experimental_gemm( grad: bool = False, ) -> Iterable[Optional[torch.Tensor]]: """Dispatch GEMM to quantizer's qgemm method.""" - assert isinstance(A, ExperimentalQuantizedTensor) and isinstance( - B, ExperimentalQuantizedTensor - ), "A and B must be ExperimentalQuantizedTensor instances" + assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors" A, B = B, A @@ -51,14 +49,14 @@ def experimental_gemm( gemm_type = GEMMType.FPROP # Extract quantizer from QuantizedTensor to get qgemm logic - # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B.quantizer? + # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B._quantizer? quantizer = None - if hasattr(A, "quantizer") and A.quantizer is not None: - quantizer = A.quantizer - elif hasattr(B, "quantizer") and B.quantizer is not None: - quantizer = B.quantizer + if hasattr(A, "_quantizer") and A._quantizer is not None: + quantizer = A._quantizer + elif hasattr(B, "_quantizer") and B._quantizer is not None: + quantizer = B._quantizer else: - raise ValueError("No quantizer found in QuantizedETensor objects") + raise ValueError("No quantizer found in QuantizedTensor objects") # Create MMParams m_params = MMParams( diff --git a/transformer_engine/pytorch/experimental/quantization.py b/transformer_engine/pytorch/experimental/quantization.py index 7d573abac..876ca7fcb 100644 --- a/transformer_engine/pytorch/experimental/quantization.py +++ b/transformer_engine/pytorch/experimental/quantization.py @@ -5,17 +5,11 @@ """Quantization API for experimental middleware between Transformer Engine and Kitchen.""" from __future__ import annotations -import abc import dataclasses import enum -from typing import Iterable, Optional, Tuple, Union import torch -from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer -from transformer_engine.pytorch.experimental import utils - @enum.unique class GEMMType(enum.Enum): @@ -33,171 +27,3 @@ class MMParams: out_dtype: torch.dtype | None = None # Use split accumulator for more accurate FP8 GEMM use_split_accumulator: bool = True - - -@dataclasses.dataclass -class ExperimentalQuantizedTensor(QuantizedTensorStorage): - """Base class for experimental quantized tensor containers. - - An experimental container to hold quantization result, including quantized tensor, optional - transposed quantized tensor, and corresponding decoding scales. - - data: torch.Tensor - the quantized tensor. - scale: torch.Tensor - the decoding scale for the quantized tensor. Shape depends on the scaling granularity. - - if scaling type is PER_TENSOR, it should be a 1D scalar tensor. - data_t: torch.Tensor - the transposed quantized tensor (computed lazily if needed). - scale_t: torch.Tensor - the decoding scale for the transposed quantized tensor. - dtype: torch.dtype - nominal tensor datatype. - device: torch.device - device of the tensor. - quant_dtype: Union[utils.Fp4Formats, torch.dtype] - low precision tensor datatype. - original_shape: Tuple[int, ...] - original shape of the tensor. - quantizer: ExperimentalQuantizer - Builder class for quantized tensor. - """ - - data: Optional[torch.Tensor] = None - scale: Optional[torch.Tensor] = None - data_t: Optional[torch.Tensor] = None - scale_t: Optional[torch.Tensor] = None - global_amax_row: Optional[torch.Tensor] = None - global_amax_col: Optional[torch.Tensor] = None - - dtype: Optional[torch.dtype] = None - device: Optional[torch.device] = None - quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None - original_shape: Optional[Tuple[int, ...]] = None - quantizer: Optional[ExperimentalQuantizer] = None - - @property - def experimental(self) -> bool: - """Flag to indicate this quantizer is using experimental Kitchen middleware.""" - return True - - def get_quantizer(self) -> ExperimentalQuantizer: - """Get builder for QuantizedExperimentalTensor - - Quantizer can be used for in-place operations. - - """ - if self.quantizer is not None: - return self.quantizer - raise ValueError("Quantizer is not set") - - def prepare_for_saving( - self, - ) -> Tuple[list[Optional[torch.Tensor]], ExperimentalQuantizedTensor]: - """Prepare the quantization result for saving for backward""" - tensors = [self.data, self.data_t, self.scale, self.scale_t] - self.data = None - self.data_t = None - self.scale = None - self.scale_t = None - return tensors, self - - def restore_from_saved( - self, tensors: list[Optional[torch.Tensor]] - ) -> list[Optional[torch.Tensor]]: - """Restore the quantization result from the saved tensors""" - self.data = tensors[0] - self.data_t = tensors[1] - self.scale = tensors[2] - self.scale_t = tensors[3] - return tensors[4:] - - def dequantize(self, *args, **kwargs) -> torch.Tensor: - """Dequantize the quantized tensor""" - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement dequantize function" - ) - - # Compatibility - @property - def _data(self): - return self.data - - @_data.setter - def _data(self, value): - self.data = value - - @property - def _scale_inv(self): - return self.scale - - @_scale_inv.setter - def _scale_inv(self, value): - self.scale = value - - -class ExperimentalQuantizer(Quantizer): - """Experimental Quantizer class - - Defines the interface for experimental quantizers. - """ - - def __init__(self, *, rowwise: bool, columnwise: bool) -> None: - super().__init__(rowwise=rowwise, columnwise=columnwise) - self.internal = True - - @property - def experimental(self) -> bool: - """Flag to indicate this quantizer is using experimental Kitchen middleware""" - return True - - @abc.abstractmethod - def qgemm( - self, - qx: torch.Tensor, - qw: torch.Tensor, - m_params: MMParams, - out_dtype: torch.dtype, - sx: torch.Tensor, - sw: torch.Tensor, - bias: torch.Tensor | None = None, - out: torch.Tensor | None = None, - accumulate: bool = False, - gemm_type: GEMMType = GEMMType.FPROP, - qresult_x: ExperimentalQuantizedTensor | None = None, - qresult_w: ExperimentalQuantizedTensor | None = None, - ) -> torch.Tensor: - """Quantized GEMM interface.""" - - def dequantize(self, *args, **kwargs) -> torch.Tensor: - """Dequantize the quantized tensor""" - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement dequantize function" - ) - - def update_quantized(self, *args, **kwargs) -> torch.Tensor: - """Update the quantized tensor with the given tensor in-place""" - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement update_quantized function" - ) - - def make_empty( - self, - shape: Iterable[int], - *, - dtype: torch.dtype = torch.float32, - device: Optional[torch.device] = None, - ) -> QuantizedTensorStorage: - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement make_empty function" - ) - - def calibrate(self, tensor: torch.Tensor) -> None: - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement calibrate function" - ) - - def _get_compatible_recipe(self) -> Union[type[Recipe], None]: - raise NotImplementedError( - f"{self.__class__.__name__} class does not implement _get_compatible_recipe function" - ) diff --git a/transformer_engine/pytorch/experimental/quantization_microblock_ref.py b/transformer_engine/pytorch/experimental/quantization_nvfp4.py similarity index 83% rename from transformer_engine/pytorch/experimental/quantization_microblock_ref.py rename to transformer_engine/pytorch/experimental/quantization_nvfp4.py index da749d237..fc50d0742 100644 --- a/transformer_engine/pytorch/experimental/quantization_microblock_ref.py +++ b/transformer_engine/pytorch/experimental/quantization_nvfp4.py @@ -2,18 +2,49 @@ # # See LICENSE for license information. -"""NVFP4 implementations for experimental middleware between Transformer Engine and Kitchen.""" +"""NVFP4 recipe reference implementation.""" -from typing import Optional, Tuple +import dataclasses +from typing import Optional, Tuple, Union import torch from transformer_engine.pytorch.experimental import quantization from transformer_engine.pytorch.experimental import utils -from transformer_engine.pytorch.experimental.quantization import ( - ExperimentalQuantizedTensor, - ExperimentalQuantizer, -) +from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer + + +def nvfp4_ref_rht_2d_quantizer_factory(role): + """ + Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). + + Usage with CustomRecipe and fp8_autocast: + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) + with fp8_autocast(fp8_recipe=custom_recipe): + output = model(input) + """ + if role == "linear_input": + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) + if role == "linear_weight": + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ) + if role == "linear_grad_output": + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) + return None def cast_to_fp4x2(x): @@ -156,8 +187,89 @@ def high_precision_gemm_ref( return y_ref -class NVFP4TensorRef(ExperimentalQuantizedTensor): - """NVFP4 tensor for middleware between Transformer Engine and Kitchen""" +@dataclasses.dataclass +class NVFP4TensorRef(QuantizedTensorStorage): + """NVFP4 tensor for middleware between Transformer Engine and Kitchen. + + Custom container to hold quantization result, including quantized tensor, optional + transposed quantized tensor, and corresponding decoding scales. + + data: torch.Tensor + the quantized tensor. + scale: torch.Tensor + the decoding scale for the quantized tensor. Shape depends on the scaling granularity. + - if scaling type is PER_TENSOR, it should be a 1D scalar tensor. + data_t: torch.Tensor + the transposed quantized tensor (computed lazily if needed). + scale_t: torch.Tensor + the decoding scale for the transposed quantized tensor. + dtype: torch.dtype + nominal tensor datatype. + device: torch.device + device of the tensor. + quant_dtype: Union[utils.Fp4Formats, torch.dtype] + low precision tensor datatype. + original_shape: Tuple[int, ...] + original shape of the tensor. + _quantizer: Quantizer + Builder class for quantized tensor. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + data_t: Optional[torch.Tensor] = None + scale_t: Optional[torch.Tensor] = None + global_amax_row: Optional[torch.Tensor] = None + global_amax_col: Optional[torch.Tensor] = None + + dtype: Optional[torch.dtype] = None + device: Optional[torch.device] = None + quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None + original_shape: Optional[Tuple[int, ...]] = None + _quantizer: Optional[Quantizer] = None + + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware.""" + return True + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: + """Prepare the quantization result for saving for backward""" + tensors = [self.data, self.data_t, self.scale, self.scale_t] + self.data = None + self.data_t = None + self.scale = None + self.scale_t = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the quantization result from the saved tensors""" + self.data = tensors[0] + self.data_t = tensors[1] + self.scale = tensors[2] + self.scale_t = tensors[3] + return tensors[4:] + + # Compatibility + @property + def _data(self): + return self.data + + @_data.setter + def _data(self, value): + self.data = value + + @property + def _scale_inv(self): + return self.scale + + @_scale_inv.setter + def _scale_inv(self, value): + self.scale = value def __repr__(self): return ( @@ -165,47 +277,10 @@ def __repr__(self): f"dtype={self.dtype}, " f"device={self.device}, " f"quant_dtype={self.quant_dtype}, " - f"data={self.dequantize(dtype=self.dtype)}, " f"original_shape={self.original_shape}" ")" ) - def quantize_( - self, - tensor: torch.Tensor, - *, - noop_flag: Optional[torch.Tensor] = None, - ) -> ExperimentalQuantizedTensor: - """In-place update of quantized data - - Parameters - ---------- - tensor: torch.Tensor - Tensor to copy from - noop_flag: torch.Tensor, optional - float32 flag indicating whether to avoid performing update - - """ - if isinstance(tensor, ExperimentalQuantizedTensor): - return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) - self.get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self - - def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: - """ - Construct plain PyTorch tensor from quantized tensor - """ - if dtype is None: - dtype = self.dtype - - # Ignore data_t for now - assert self.data is not None, "QuantizedTensor has no valid tensor data" - assert self.scale is not None, "QuantizedTensor has no valid scale" - tensor_data = self.data - tensor_scale = self.scale - # Dispatch to the quantizer - return self.get_quantizer().dequantize(tensor_data, tensor_scale, dtype=dtype) - def update_usage( self, rowwise_usage: Optional[bool] = None, @@ -224,10 +299,10 @@ def update_usage( # Generate data that is required if needs_data and not has_data: - raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose") + raise RuntimeError("Cannot generate FP4 data, even from FP4 data transpose") if needs_data_transpose and not has_data_transpose: if not has_data: - raise RuntimeError("FP8 data is required to generate FP8 data transpose") + raise RuntimeError("FP4 data is required to generate FP4 data transpose") self._create_transpose() # Delete data that is not required @@ -262,7 +337,7 @@ def get_wgrad_sign_vector() -> torch.Tensor: ) -class NVFP4QuantizerRef(ExperimentalQuantizer): +class NVFP4QuantizerRef(Quantizer): """NVFP4 quantizer for middleware between Transformer Engine and Kitchen""" def __init__( @@ -277,6 +352,8 @@ def __init__( with_random_sign_mask: bool = True, ): super().__init__(rowwise=rowwise, columnwise=columnwise) + self.internal = True + self.dtype = dtype self.pow_2_scales = pow_2_scales self.eps = eps @@ -284,6 +361,11 @@ def __init__( self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask + @property + def experimental(self) -> bool: + """Flag to indicate this quantizer is using experimental Kitchen middleware""" + return True + @staticmethod def _build_hadamard_matrix( size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True @@ -500,7 +582,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ - sx: scale tensor for qx (if rowwise_usage), None otherwise - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise - - global_amax: global amax tensor + - global_amax_row, global_amax_col: global amax tensors """ if self.pow_2_scales: assert self.quant_tile_shape == ( @@ -607,25 +689,25 @@ def quantize( dtype=tensor.dtype, device=tensor.device, quant_dtype=self.dtype, - quantizer=self, + _quantizer=self, original_shape=original_shape, ) def update_quantized( self, src: torch.Tensor, - dst: ExperimentalQuantizedTensor, + dst: QuantizedTensorStorage, *, noop_flag: Optional[torch.Tensor] = None, - ) -> ExperimentalQuantizedTensor: + ) -> QuantizedTensorStorage: """Update the quantized tensor with the given tensor in-place Parameters ---------- src: torch.Tensor Source tensor to copy from - dst: ExperimentalQuantizedTensor - Destination ExperimentalQuantizedTensor to update + dst: QuantizedTensorStorage + Destination QuantizedTensorStorage to update noop_flag: torch.Tensor, optional float32 flag indicating whether to avoid performing update """ @@ -642,14 +724,15 @@ def update_quantized( if src.ndim > 2: src = src.view(-1, src.shape[-1]) - qx, sx, qx_t, sx_t, global_amax = self._quantize(src) + qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src) # Update the destination with new data dst.data = qx dst.scale = sx dst.data_t = qx_t dst.scale_t = sx_t - dst.global_amax = global_amax + dst.global_amax_row = global_amax_row + dst.global_amax_col = global_amax_col dst.dtype = src.dtype dst.quant_dtype = self.dtype dst.original_shape = original_shape @@ -665,9 +748,7 @@ def supports_allgather_fp8(self) -> bool: """ return False - def transpose_qresult( - self, qresult: quantization.ExperimentalQuantizedTensor - ) -> quantization.ExperimentalQuantizedTensor: + def transpose_qresult(self, qresult: QuantizedTensorStorage) -> QuantizedTensorStorage: """Convert row-wise data to column-wise data (?) TODO(etsykunov): Confirm docstring is correct. @@ -687,17 +768,11 @@ def is_data_t_transposed_in_memory(self) -> bool: """ raise NotImplementedError("Not implemented yet") - def dequantize( - self, tensor: torch.Tensor, scale: torch.Tensor, dtype: Optional[torch.dtype] = None - ) -> torch.Tensor: - """Dequantize the quantized tensor""" - raise NotImplementedError("Not implemented yet") - def qgemm( self, qx: torch.Tensor, qw: torch.Tensor, - m_params: quantization.MMParams, + m_params: quantization.MMParams, # pylint: disable=unused-argument out_dtype: torch.dtype, sx: torch.Tensor, sw: torch.Tensor, @@ -705,9 +780,10 @@ def qgemm( out: torch.Tensor | None = None, accumulate: bool = False, gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP, - qresult_x: quantization.ExperimentalQuantizedTensor | None = None, - qresult_w: quantization.ExperimentalQuantizedTensor | None = None, + qresult_x: QuantizedTensorStorage | None = None, + qresult_w: QuantizedTensorStorage | None = None, ) -> torch.Tensor: + """Python implementation of microblock FP4 GEMM.""" assert bias is None, "Bias is implemented for FP4 GEMM." high_precision_x = cast_from_fp4x2(qx, out_dtype) diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 3505a6830..6151ecafd 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -11,10 +11,8 @@ import torch from .. import cpp_extensions as tex -from .. import experimental from ..constants import TE_DType from ..export import is_in_onnx_export_mode -from ..tensor.utils import is_experimental from ..utils import get_default_init_method @@ -172,32 +170,6 @@ def noop_cat( return _NoopCatFunc.apply(dim, *tensors) -def get_module_quantizers( - module: torch.nn.Module, - fp8_output: bool, - fp8_grad: bool, - debug: bool, -): - """Return the 6-tuple of quantizers for a module in a centralized way. - - Routing policy: - - If experimental quantization is enabled via environment and module.fp8 is True, - return experimental quantizers. - - Otherwise, return the module's own quantizers (debug or regular). - """ - if getattr(module, "fp8", False) and is_experimental(): - # TODO(etsykunov): Quantizer instantiation should be better - # done in the module's constructor - qlinear_params = experimental.config.set_qlinear_params() - - if qlinear_params is not None: - return experimental.config.get_experimental_quantizers(module.fp8, qlinear_params) - - if not debug: - return module._get_quantizers(fp8_output, fp8_grad) - return module._get_debug_quantizers(fp8_output, fp8_grad) - - @dataclasses.dataclass class _ParameterInitMeta: """ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0559ae7ce..824fcc0a7 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -55,7 +55,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ._common import apply_normalization, noop_cat, WeightGradStore, get_module_quantizers +from ._common import apply_normalization, noop_cat, WeightGradStore from ..tensor.quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, @@ -1541,7 +1541,11 @@ def forward( # Get concatenated weight and bias tensors weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) if debug: if self.no_debug_features_active(quantizers): debug = False diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 67124c157..12b7bac01 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -25,7 +25,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ._common import noop_cat, WeightGradStore, get_module_quantizers +from ._common import noop_cat, WeightGradStore from ..fp8 import FP8GlobalStateManager from ..utils import ( cast_if_needed, @@ -1428,7 +1428,11 @@ def forward( weight_tensor, bias_tensor = self._get_weight_and_bias_tensors() - quantizers = get_module_quantizers(self, fp8_output, fp8_grad, debug) + quantizers = ( + self._get_quantizers(fp8_output, fp8_grad) + if not debug + else self._get_debug_quantizers(fp8_output, fp8_grad) + ) if debug: if self.no_debug_features_active(quantizers): debug = False From ca6fedcfd128ec6349a1fc481e007eafea531e1b Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Tue, 14 Oct 2025 10:38:01 -0700 Subject: [PATCH 054/141] [JAX] Add BRCM support for THD (#2242) * Add BRCM support when creating a test mask for fused attn Signed-off-by: Kshitij Lakhani * Add support for BRCM to correctly generate the mask needed for calculating the seqlens and offsets for THD Signed-off-by: Kshitij Lakhani * Skip drop=0 and no_bias case for BRCM as cuDNN does not suport this Signed-off-by: Kshitij Lakhani * Skip BRCM test cases where max_seqlen_q > max_seqlen_kv Signed-off-by: Kshitij Lakhani * Refactor the segment id run length code for BRCM seqoffset and seqlens calculations Signed-off-by: Kshitij Lakhani * Fix the drop inequality skip condition in fused attn Signed-off-by: Kshitij Lakhani * nit: Adjust the BRCM id name in the test to make it consistent Signed-off-by: Kshitij Lakhani * Fix the brcm mask condition. Fix the condition for cross atnn type pattern to only apply for brcm Change the num segments per sequence to 3 instead of 2 Reduce one test pattern data size and make it such that it triggers brcm Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix lint errors Signed-off-by: Kshitij Lakhani * Fix incorrectly changed dtype to numpy bool_ rather than native python bool Signed-off-by: Kshitij Lakhani * Restore the numsegments to earlier value Signed-off-by: Kshitij Lakhani * Add example for THD BRCM Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 49 +++++++++++-- transformer_engine/jax/attention.py | 109 +++++++++++++++++++++++++++- 2 files changed, 148 insertions(+), 10 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 87dfc113c..710cc134b 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -32,6 +32,7 @@ reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, fused_attn, + run_length_fill, make_swa_mask, SequenceDescriptor, CPStrategy, @@ -172,15 +173,34 @@ def make_mask( jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape ) - # causal mask - if attn_mask_type.is_causal(): + if attn_mask_type.is_bottom_right(): + run_length_out_q = run_length_fill(segment_ids_q) + run_length_out_kv = run_length_fill(segment_ids_kv) + bottom_right_causal_mask = make_attention_mask( + run_length_out_q - segment_pos_q, + run_length_out_kv - segment_pos_kv, + jnp.less_equal, + ) + inv_mask = combine_masks(bottom_right_causal_mask, inv_mask) + elif attn_mask_type.is_causal(): inv_causal_mask = make_attention_mask( segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) inv_mask = combine_masks(inv_causal_mask, inv_mask) # sliding window mask - inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_) + inv_swa_mask = ( + make_swa_mask( + segment_pos_q, + segment_pos_kv, + window_size, + dtype=jnp.bool, + segment_ids_q=segment_ids_q, + segment_ids_kv=segment_ids_kv, + ) + if attn_mask_type.is_bottom_right() + else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_) + ) inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = jnp.logical_not(inv_mask) return mask @@ -338,6 +358,16 @@ def _check_configs(self): if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") + if self.attn_mask_type.is_bottom_right(): + if self.max_seqlen_q > self.max_seqlen_kv: + pytest.skip( + f"BRCM requires cross attn type pattern, i.e.max_seqlen_kv >= max_seqlen_q" + ) + if self.attn_bias_type is not AttnBiasType.NO_BIAS: + pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM") + if self.dropout_prob != 0.0: + pytest.skip(f"cuDNN does not support non-zero dropoouts for BRCM") + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") @@ -526,7 +556,11 @@ def generate_random_segment_ids( self.pad_kv = self.pad_q else: # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support - min_segment_len = None if self.window_size is None else self.seqlens_q + min_segment_len = None + if ( + self.window_size is not None or self.attn_mask_type.is_bottom_right() + ): # SWA or BRCM requires kv_len >= q_len + min_segment_len = self.seqlens_q self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, @@ -937,6 +971,9 @@ def check_dqkv(primitive, reference, pad, idx): pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + pytest.param( + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" + ), ], ) @pytest.mark.parametrize( @@ -958,14 +995,14 @@ def check_dqkv(primitive, reference, pad, idx): ), pytest.param( 2, - 2048, + 512, 1024, 12, 12, 64, 64, jnp.bfloat16, - id="2-2048-1024-12-12-64-64-BF16-CROSS", + id="2-512-1024-12-12-64-64-BF16-CROSS", ), pytest.param( 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 093146162..1ce44a2b9 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -209,6 +209,8 @@ def make_swa_mask( segment_pos_kv: jnp.ndarray, window_size: Optional[Tuple[int, int]] = None, dtype: jax.typing.DTypeLike = jnp.float32, + segment_ids_q: jnp.ndarray = None, + segment_ids_kv: jnp.ndarray = None, ): """ Generate a sliding window mask (1 = attend, 0 = masked). @@ -227,6 +229,10 @@ def make_swa_mask( Defaults to None. dtype (jax.typing.DTypeLike, optional): Mask data type. Defaults to jnp.float32. + segment_ids_q (jnp.ndarray): + Query segment id that each token belongs to + segment_ids_kv (jnp.ndarray): + Key/value segment id that each token belongs to Returns: jnp.ndarray: @@ -240,6 +246,18 @@ def make_swa_mask( right_window = jnp.inf if right_window < 0 else right_window pos_q = jnp.expand_dims(segment_pos_q, axis=-1) pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2) + # For Bottom Right Causal Mask (BRCM) + if segment_ids_q is not None and segment_ids_kv is not None: + run_length_q = run_length_fill(segment_ids_q) + run_length_kv = run_length_fill(segment_ids_kv) + run_length_q_exp = jnp.expand_dims(run_length_q, axis=-1) + run_length_kv_exp = jnp.expand_dims(run_length_kv, axis=-2) + bottom_right_inv_swa_mask = ( + run_length_q_exp - pos_q + left_window >= run_length_kv_exp - pos_kv + ) + bottom_right_inv_swa_mask = jnp.expand_dims(bottom_right_inv_swa_mask, axis=-3) + return bottom_right_inv_swa_mask.astype(dtype) + # All other cases other than BRCM inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window) inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3) return inv_swa_mask.astype(dtype) @@ -420,6 +438,42 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path( ) +def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray: + """ + Returns an array of run-lengths of the flattened segment ids + """ + # Example for run_length_fill_flattened: + # Input segment_ids_flattened: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]] + # run_ids: [[0 0 1 1 1 2 3 4 5 5 5 5 5 6 6 6], [0 1 1 2 2 2 3 3 4 4 5 5 5 5 6 6]] + # counts: [[2 3 1 1 1 5 3 0 0 0 0 0 0 0 0 0], [1 2 3 2 2 4 2 0 0 0 0 0 0 0 0 0]] + # Returns segment_ids_run_length_1d: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]] + boundary = jnp.concatenate( + [jnp.broadcast_to(True, (1,)), segment_ids_flattened[1:] != segment_ids_flattened[:-1]] + ) + run_ids = jnp.cumsum(boundary) - 1 + # Each element could, in worst case, start a run + max_runs = segment_ids_flattened.shape[-1] + counts = jnp.bincount(run_ids, length=max_runs) + # Fill in the missing values + segment_ids_run_length_1d = counts[run_ids] + segment_ids_run_length_1d = jnp.where(segment_ids_flattened == 0, 0, segment_ids_run_length_1d) + return segment_ids_run_length_1d + + +def run_length_fill(segment_ids) -> jnp.ndarray: + """ + Returns an array of run-lengths of the segment ids, with shape preserved + """ + # Example for run_length_fill: + # Input segment_ids: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]] + # Returns run length: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]] + # Flatten all dimension except the last one prior to executing vmap run length + orig_shape = segment_ids.shape + segment_ids_flat = segment_ids.reshape(-1, orig_shape[-1]) + run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) + return run_length_segment_id_shape.reshape(orig_shape) + + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, segment_ids_kv, @@ -443,7 +497,10 @@ def _segment_ids_pos_to_seqlens_offsets( # # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to # examine only O(Q+KV) elements. - if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1): + # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well + if (attn_mask_type.is_causal() and window_size is None) or ( + window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + ): return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq ) @@ -459,8 +516,41 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids_kv, lambda x, y: jnp.equal(x, y) * x, ) + # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied attn_mask = segment_mask - if attn_mask_type.is_causal(): + if attn_mask_type.is_bottom_right(): + run_length_out_q = run_length_fill(segment_ids_q) + run_length_out_kv = run_length_fill(segment_ids_kv) + # Example for brcm: + # run_length_out_q: [3 3 3 0 4 4 4 4] + # segment_pos_q: [0 1 2 3 0 1 2 3] + # segment_ids_q: [1 1 1 0 2 2 2 2] + # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10] + # segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9] + # segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2] + # brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] + # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] + # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] + # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] + # [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] + # [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] + # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] + # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]] + # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0] + # [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0] + # [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] + # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]] + bottom_right_causal_mask = make_attention_mask( + run_length_out_q - segment_pos_q, + run_length_out_kv - segment_pos_kv, + jnp.less_equal, + ) + attn_mask = jnp.logical_and(segment_mask, bottom_right_causal_mask) + elif attn_mask_type.is_causal(): causal_mask = make_attention_mask( segment_pos_q, segment_pos_kv, @@ -468,7 +558,19 @@ def _segment_ids_pos_to_seqlens_offsets( ) attn_mask = jnp.logical_and(segment_mask, causal_mask) - swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + # TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets + swa_mask = ( + make_swa_mask( + segment_pos_q, + segment_pos_kv, + window_size, + dtype=jnp.bool, + segment_ids_q=segment_ids_q, + segment_ids_kv=segment_ids_kv, + ) + if attn_mask_type.is_bottom_right() + else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + ) attn_mask = jnp.logical_and(attn_mask, swa_mask) attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) @@ -1125,5 +1227,4 @@ def fused_attn( context_parallel_axis=context_parallel_axis, context_checkpoint_name=context_checkpoint_name, ) - return output From 85a919973d062063b86000856b5c0f258a27380d Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 14 Oct 2025 14:57:00 -0700 Subject: [PATCH 055/141] Generalize quantization APIs for FP8/FP4/.. recipes (#2256) * Initial API change Signed-off-by: Kirthi Shankar Sivamani * Change all imports and api Signed-off-by: Kirthi Shankar Sivamani * format Signed-off-by: Kirthi Shankar Sivamani * fixes Signed-off-by: Kirthi Shankar Sivamani * fix typo Signed-off-by: Kirthi Shankar Sivamani * fix recipe tets Signed-off-by: Kirthi Shankar Sivamani * fix more tests Signed-off-by: Kirthi Shankar Sivamani * fix docs, tests, and make Jax change as well Signed-off-by: Kirthi Shankar Sivamani * Change internal uses of fp8_autocast Signed-off-by: Kirthi Shankar Sivamani * Address nits Signed-off-by: Kirthi Shankar Sivamani * rename file Signed-off-by: Kirthi Shankar Sivamani * CG function, and small test fixes Signed-off-by: Kirthi Shankar Sivamani * Change instances of make_graphed_callables internally Signed-off-by: Kirthi Shankar Sivamani * Fix distributed tests Signed-off-by: Kirthi Shankar Sivamani * Review Signed-off-by: Kirthi Shankar Sivamani * Review Signed-off-by: Kirthi Shankar Sivamani * Fix test and add more docs Signed-off-by: Kirthi Shankar Sivamani * Cleanup test imports and minimize internal file imports Signed-off-by: Kirthi Shankar Sivamani * Make is_bf16_available public Signed-off-by: Kirthi Shankar Sivamani * fixes Signed-off-by: Kirthi Shankar Sivamani * fix tests Signed-off-by: Kirthi Shankar Sivamani * Better docs and better api Signed-off-by: Kirthi Shankar Sivamani * format Signed-off-by: Kirthi Shankar Sivamani * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * fix nvfp4 test Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- README.rst | 4 +- benchmarks/linear/benchmark_grouped_linear.py | 7 +- docs/api/jax.rst | 1 + docs/api/pytorch.rst | 22 +- docs/debug/1_getting_started.rst | 2 +- docs/examples/advanced_optimizations.ipynb | 26 +- docs/examples/fp8_primer.ipynb | 20 +- docs/examples/onnx/onnx_export.ipynb | 4 +- docs/examples/quickstart.ipynb | 6 +- docs/examples/quickstart_utils.py | 10 +- docs/examples/te_gemma/te_gemma.py | 10 +- .../te_gemma/te_gemma_loading_weights.py | 8 +- .../tutorial_generation_gemma_with_te.ipynb | 32 +- docs/examples/te_gemma/utils.py | 2 +- docs/faq.rst | 4 +- .../jax/collective_gemm/test_dense_grad.py | 8 +- examples/jax/collective_gemm/test_gemm.py | 6 +- .../test_layernorm_mlp_grad.py | 8 +- examples/jax/encoder/README.md | 6 +- .../encoder/test_model_parallel_encoder.py | 6 +- examples/jax/encoder/test_multigpu_encoder.py | 8 +- .../encoder/test_multiprocessing_encoder.py | 6 +- .../jax/encoder/test_single_gpu_encoder.py | 4 +- examples/jax/mnist/README.md | 4 +- examples/jax/mnist/test_single_gpu_mnist.py | 4 +- .../te_layer_with_overlap.py | 4 +- examples/pytorch/fsdp/README.md | 2 +- examples/pytorch/fsdp/fsdp.py | 8 +- examples/pytorch/mnist/main.py | 6 +- tests/jax/test_distributed_dense.py | 6 +- tests/jax/test_distributed_helper.py | 6 +- tests/jax/test_distributed_layernorm.py | 6 +- tests/jax/test_distributed_layernorm_mlp.py | 18 +- tests/jax/test_distributed_softmax.py | 4 +- tests/jax/test_fused_attn.py | 10 +- tests/jax/test_helper.py | 36 +- tests/jax/test_layer.py | 10 +- .../attention/run_attention_with_cp.py | 11 +- tests/pytorch/attention/test_attention.py | 43 +- .../attention/test_attention_with_cp.py | 2 +- tests/pytorch/attention/test_cp_utils.py | 1 - tests/pytorch/attention/test_kv_cache.py | 26 +- tests/pytorch/debug/run_distributed.py | 12 +- tests/pytorch/debug/test_api_features.py | 2 +- tests/pytorch/debug/test_config.py | 2 +- tests/pytorch/debug/test_log.py | 20 +- tests/pytorch/debug/test_numerics.py | 12 +- tests/pytorch/debug/test_sanity.py | 5 +- .../run_cast_master_weights_to_fp8.py | 36 +- tests/pytorch/distributed/run_fsdp2_model.py | 4 +- .../distributed/run_gemm_with_overlap.py | 11 +- .../distributed/run_layer_with_overlap.py | 10 +- tests/pytorch/distributed/run_numerics.py | 15 +- .../pytorch/distributed/run_numerics_exact.py | 13 +- .../test_cast_master_weights_to_fp8.py | 8 +- .../distributed/test_comm_gemm_overlap.py | 5 +- tests/pytorch/distributed/test_fusible_ops.py | 37 +- .../test_fusible_ops_with_userbuffers.py | 20 +- tests/pytorch/distributed/test_numerics.py | 12 +- .../distributed/test_numerics_exact.py | 12 +- tests/pytorch/distributed/test_sanity.py | 3 +- tests/pytorch/distributed/test_torch_fsdp2.py | 7 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 7 +- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 36 +- .../nvfp4/test_nvfp4_quantize_exact.py | 12 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 13 +- tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py | 6 +- tests/pytorch/test_checkpoint.py | 14 +- tests/pytorch/test_cpu_offloading.py | 13 +- tests/pytorch/test_cuda_graphs.py | 37 +- tests/pytorch/test_custom_recipe.py | 40 +- tests/pytorch/test_deferred_init.py | 1 - .../test_float8_blockwise_gemm_exact.py | 10 +- .../test_float8_blockwise_scaling_exact.py | 11 +- .../test_float8_current_scaling_exact.py | 14 +- tests/pytorch/test_float8blockwisetensor.py | 5 +- tests/pytorch/test_float8tensor.py | 6 +- tests/pytorch/test_fused_optimizer.py | 33 +- tests/pytorch/test_fusible_ops.py | 90 +- tests/pytorch/test_hf_integration.py | 2 +- tests/pytorch/test_multi_tensor.py | 2 +- tests/pytorch/test_numerics.py | 65 +- tests/pytorch/test_onnx_export.py | 26 +- tests/pytorch/test_parallel_cross_entropy.py | 2 +- tests/pytorch/test_permutation.py | 19 +- tests/pytorch/test_recipe.py | 59 +- tests/pytorch/test_sanity.py | 67 +- tests/pytorch/utils.py | 6 +- transformer_engine/common/recipe/__init__.py | 4 +- .../debug/features/fake_quant.py | 2 +- transformer_engine/jax/__init__.py | 3 +- transformer_engine/jax/flax/transformer.py | 2 +- transformer_engine/jax/quantize/helper.py | 111 +- transformer_engine/pytorch/__init__.py | 17 +- .../dot_product_attention/backends.py | 4 +- .../dot_product_attention/context_parallel.py | 6 +- .../dot_product_attention.py | 24 +- .../attention/dot_product_attention/utils.py | 4 +- .../pytorch/attention/multi_head_attention.py | 4 +- transformer_engine/pytorch/distributed.py | 14 +- transformer_engine/pytorch/fp8.py | 1280 +-------------- transformer_engine/pytorch/graph.py | 189 ++- transformer_engine/pytorch/module/base.py | 10 +- .../pytorch/module/fp8_padding.py | 2 +- .../pytorch/module/fp8_unpadding.py | 2 +- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/ops/_common.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 6 +- .../pytorch/ops/basic/quantize.py | 6 +- .../ops/fused/backward_activation_bias.py | 2 +- .../fused/forward_linear_bias_activation.py | 2 +- .../ops/fused/forward_linear_bias_add.py | 2 +- .../ops/fused/forward_linear_scale_add.py | 2 +- .../ops/fused/userbuffers_forward_linear.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 2 +- transformer_engine/pytorch/ops/op.py | 6 +- transformer_engine/pytorch/quantization.py | 1396 +++++++++++++++++ transformer_engine/pytorch/utils.py | 28 +- 121 files changed, 2385 insertions(+), 2016 deletions(-) create mode 100644 transformer_engine/pytorch/quantization.py diff --git a/README.rst b/README.rst index 19ab1a7d9..380a99edf 100644 --- a/README.rst +++ b/README.rst @@ -86,7 +86,7 @@ PyTorch fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) # Enable autocasting for the forward pass - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with te.autocast(enabled=True, recipe=fp8_recipe): out = model(inp) loss = out.sum() @@ -121,7 +121,7 @@ Flax fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID) # Enable autocasting for the forward pass - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with te.autocast(enabled=True, recipe=fp8_recipe): model = te_flax.DenseGeneral(features=HIDDEN) def loss_fn(params, other_vars, inp): diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 44f1c8967..48adb2a10 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -6,11 +6,10 @@ import torch import torch.utils.benchmark as benchmark import pandas as pd -import pathlib from transformer_engine.pytorch.module import GroupedLinear from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager from contextlib import nullcontext """ @@ -51,9 +50,7 @@ def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None): assert mode in ["fwd_only", "fwd_bwd"] - fp8_context = ( - fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() - ) + fp8_context = autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() # print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}") if mode == "fwd_only": diff --git a/docs/api/jax.rst b/docs/api/jax.rst index 1af5cd1d0..789b27e59 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -30,6 +30,7 @@ Modules .. autoapifunction:: transformer_engine.jax.fp8_autocast +.. autoapifunction:: transformer_engine.jax.autocast .. autoapifunction:: transformer_engine.jax.update_collections diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 04b49fac2..c456f1a6a 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -41,8 +41,28 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.fp8_model_init +.. autoapifunction:: transformer_engine.pytorch.autocast + +.. autoapifunction:: transformer_engine.pytorch.quantized_model_init + .. autoapifunction:: transformer_engine.pytorch.checkpoint +.. autoapifunction:: transformer_engine.pytorch.is_fp8_available + +.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available + +.. autoapifunction:: transformer_engine.pytorch.is_fp8_block_scaling_available + +.. autoapifunction:: transformer_engine.pytorch.is_nvfp4_available + +.. autoapifunction:: transformer_engine.pytorch.is_bf16_available + +.. autoapifunction:: transformer_engine.pytorch.get_cudnn_version + +.. autoapifunction:: transformer_engine.pytorch.get_device_compute_capability + +.. autoapifunction:: transformer_engine.pytorch.get_default_recipe + .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context @@ -64,4 +84,4 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.destroy_ub .. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode - :members: FP8, NONE \ No newline at end of file + :members: FP8, NONE diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index a7b86dad3..906c62556 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -69,7 +69,7 @@ Let's look at a simple example of training a Transformer layer using Transformer for epoch in range(5): transformer_layer.train() optimizer.zero_grad() - with te.fp8_autocast(enabled=True): + with te.autocast(enabled=True): output = transformer_layer(dummy_input) loss = criterion(output, dummy_target) loss.backward() diff --git a/docs/examples/advanced_optimizations.ipynb b/docs/examples/advanced_optimizations.ipynb index 3d889859b..5dc9cb92f 100644 --- a/docs/examples/advanced_optimizations.ipynb +++ b/docs/examples/advanced_optimizations.ipynb @@ -71,7 +71,7 @@ " amax_compute_algo=\"max\",\n", ")\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = basic_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "\n", @@ -81,7 +81,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] }, @@ -135,7 +135,7 @@ "\n", "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with [torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n", "\n", - "One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager." + "One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager." ] }, { @@ -169,7 +169,7 @@ ")\n", "\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=world_group):\n", " y = parallel_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "\n", @@ -179,10 +179,10 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = {\n", + " autocast_kwargs = {\n", " \"enabled\": True,\n", - " \"fp8_recipe\": fp8_recipe,\n", - " \"fp8_group\": world_group,\n", + " \"recipe\": fp8_recipe,\n", + " \"amax_reduction_group\": world_group,\n", " },\n", ")" ] @@ -234,7 +234,7 @@ " param.main_grad = torch.zeros_like(param, dtype=torch.float32)\n", "\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = wgrad_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "for param in wgrad_transformer.parameters():\n", @@ -248,7 +248,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] }, @@ -268,7 +268,7 @@ "\n", "\n", "\n", - "Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n", + "Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n", "\n", "
\n", "\n", @@ -303,12 +303,12 @@ "weight_caching_transformer.to(dtype=dtype).cuda()\n", "\n", "# Cast weights in first gradient accumulation step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)\n", "y.backward(dy)\n", "\n", "# Reuse FP8 weights in subsequent gradient accumulation steps\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)\n", "y.backward(dy)\n", "\n", @@ -318,7 +318,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None, \"is_first_microbatch\": False },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] } diff --git a/docs/examples/fp8_primer.ipynb b/docs/examples/fp8_primer.ipynb index a8ebd770c..457d13921 100644 --- a/docs/examples/fp8_primer.ipynb +++ b/docs/examples/fp8_primer.ipynb @@ -132,7 +132,7 @@ " - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n", " - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n", "disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n", - " - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `fp8_autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n", + " - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n", "\n", "The full linear layer utilizing NVFP4 is presented in Figure 9.\n", "\n", @@ -193,7 +193,7 @@ "source": [ "### FP8 autocasting\n", "\n", - "Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager." + "Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager." ] }, { @@ -212,7 +212,7 @@ "\n", "inp = torch.rand((1024, 768)).cuda()\n", "\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " out_fp8 = my_linear(inp)" ] }, @@ -221,7 +221,7 @@ "id": "e41161f1", "metadata": {}, "source": [ - "The `fp8_autocast` context manager hides the complexity of handling FP8:\n", + "The `autocast` context manager hides the complexity of handling FP8:\n", "\n", "- All FP8-safe operations have their inputs cast to FP8\n", "- Amax history is updated\n", @@ -243,9 +243,9 @@ "source": [ "### Handling backward pass\n", "\n", - "When a model is run inside the `fp8_autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `fp8_autocast` context manager aggregates the tensors before performing the communication.\n", + "When a model is run inside the `autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `autocast` context manager aggregates the tensors before performing the communication.\n", "\n", - "Due to this aggregation the backward call needs to happen outside of the `fp8_autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass." + "Due to this aggregation the backward call needs to happen outside of the `autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass." ] }, { @@ -257,11 +257,11 @@ "source": [ "loss_fp8 = out_fp8.mean()\n", "\n", - "loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast\n", + "loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside autocast\n", "\n", "out_fp32 = my_linear(inp)\n", "loss_fp32 = out_fp32.mean()\n", - "loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside fp8_autocast" + "loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside autocast" ] }, { @@ -451,9 +451,9 @@ "\n", "inp = inp.bfloat16()\n", "\n", - "with te.fp8_autocast(fp8_recipe=nvfp4_recipe):\n", + "with te.autocast(recipe=nvfp4_recipe):\n", " y = my_linear1(inp)\n", - " with te.fp8_autocast(fp8_recipe=mxfp8_recipe):\n", + " with te.autocast(recipe=mxfp8_recipe):\n", " out = my_linear2(y)\n", "\n", "print(out)\n", diff --git a/docs/examples/onnx/onnx_export.ipynb b/docs/examples/onnx/onnx_export.ipynb index 26ac71188..5ffc91854 100644 --- a/docs/examples/onnx/onnx_export.ipynb +++ b/docs/examples/onnx/onnx_export.ipynb @@ -80,7 +80,7 @@ "model = Model().eval().cuda()\n", "inps = (torch.randn([S, B, H], device=\"cuda\"),)\n", "def _inference(fp8_enabled):\n", - " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n", + " with torch.no_grad(), te.pytorch.autocast(enabled=fp8_enabled):\n", " model(*inps)\n", "\n", "te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n", @@ -138,7 +138,7 @@ "from transformer_engine.pytorch.export import te_translation_table\n", "\n", "def export(model, fname, inputs, fp8=True):\n", - " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):\n", + " with torch.no_grad(), te.pytorch.autocast(enabled=fp8):\n", " # ! IMPORTANT !\n", " # Transformer Engine models must have warm-up run\n", " # before export. FP8 recipe during warm-up should \n", diff --git a/docs/examples/quickstart.ipynb b/docs/examples/quickstart.ipynb index 3b50ce161..0ad2f4fee 100644 --- a/docs/examples/quickstart.ipynb +++ b/docs/examples/quickstart.ipynb @@ -548,7 +548,7 @@ "\n", "
\n", "\n", - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." + "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." ] }, { @@ -567,7 +567,7 @@ "fp8_format = Format.HYBRID\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "torch.manual_seed(1234)\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n", " y = te_transformer(x, attention_mask=None)" ] }, @@ -591,7 +591,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] } diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index f7a81d4d8..473fce7fe 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -13,7 +13,7 @@ def speedometer( input: torch.Tensor, output_grad: torch.Tensor, forward_kwargs: dict = {}, - fp8_autocast_kwargs: Optional[dict] = None, + autocast_kwargs: Optional[dict] = None, timing_iters: int = 50, warmup_iters: int = 50, ) -> None: @@ -23,20 +23,20 @@ def speedometer( """ start = torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - if fp8_autocast_kwargs is None: - fp8_autocast_kwargs = {"enabled": False} + if autocast_kwargs is None: + autocast_kwargs = {"enabled": False} # Warmup runs torch.cuda.synchronize() for _ in range(warmup_iters): - with te.fp8_autocast(**fp8_autocast_kwargs): + with te.autocast(**autocast_kwargs): output = module(input, **forward_kwargs) output.backward(output_grad) # Timing runs start.record() for _ in range(timing_iters): - with te.fp8_autocast(**fp8_autocast_kwargs): + with te.autocast(**autocast_kwargs): output = module(input, **forward_kwargs) output.backward(output_grad) end.record() diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 6285fea1a..d3de8a185 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -14,7 +14,7 @@ import transformer_engine as te from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding from transformer_engine.common.recipe import Format, DelayedScaling -from transformer_engine.pytorch.fp8 import get_default_fp8_recipe +from transformer_engine.pytorch.quantization import get_default_fp8_recipe import transformers from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel @@ -461,8 +461,8 @@ def generate( # Both autocasts are needed: FP8 for operations that can run in lower # precision and BF16 for those that cannot. - with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( - enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.autocast( + enabled=self.config.fp8, recipe=self.fp8_recipe if self.config.fp8 else None ): lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # If padding is at the beginning, then shift it to the end @@ -694,8 +694,8 @@ def record_graph(self, function, input_tensor, **sample_kwargs): graphed_function = te.pytorch.make_graphed_callables( function, (input_tensor,), - fp8_enabled=self.config.fp8, - fp8_recipe=fp8_recipe, + enabled=self.config.fp8, + recipe=fp8_recipe, allow_unused_input=True, num_warmup_iters=5, sample_kwargs=sample_kwargs, diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py index d0df9edc5..36b0a5b73 100755 --- a/docs/examples/te_gemma/te_gemma_loading_weights.py +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -9,7 +9,7 @@ from typing import List -from transformer_engine.pytorch.fp8 import fp8_model_init +from transformer_engine.pytorch.quantization import quantized_model_init from transformers.modeling_utils import load_state_dict from transformers.utils.hub import get_checkpoint_shard_files @@ -88,10 +88,10 @@ def load_te_model(cls, config): config.use_cache = False # To make TransformerLayer compatible with GemmaModel # Loading model with FP8 only weights needs both the following context managers. - # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights. + # 1. quantized_model_init(config.quantized_model_init) to tell TE to use FP8 only weights. # 2. torch.no_grad() during TE modules' initilization so that they respect - # the `fp8_model_init` context manager. - with torch.no_grad(), fp8_model_init(config.fp8_model_init): + # the `quantized_model_init` context manager. + with torch.no_grad(), quantized_model_init(config.quantized_model_init): # Just create a model with random weights. vanilla_model = cls(config).cuda() diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index cc8675cfd..c31e272b2 100755 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -77,7 +77,7 @@ "\n", "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n", "\n", - "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", + "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", "\n", "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", "\n", @@ -94,12 +94,12 @@ "\n", "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n", "\n", - "The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", + "The Transformer Engine includes a wrapper `quantized_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", "\n", "
\n", "\"\"\n", "
\n", - "Figure 3: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using fp8_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", + "Figure 3: Model under autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using quantized_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", "
\n", "
\n", "\n", @@ -405,8 +405,8 @@ " graphed_function = te.pytorch.make_graphed_callables(\n", " function,\n", " (input_tensor,),\n", - " fp8_enabled=self.config.fp8,\n", - " fp8_recipe=fp8_recipe,\n", + " enabled=self.config.fp8,\n", + " recipe=fp8_recipe,\n", " allow_unused_input=True,\n", " num_warmup_iters=5,\n", " sample_kwargs=sample_kwargs,\n", @@ -540,14 +540,14 @@ "source": [ "### Calibrating FP8 scaling factors for correctness\n", "\n", - "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", + "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", "\n", "1. Model weight tensors\n", "2. Input tensors\n", "\n", "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n", "\n", - "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", + "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", "\n", "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n", " \n", @@ -590,14 +590,14 @@ "model = init_te_gemma_model(run_config)\n", "\n", "# Calibration\n", - "with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n", + "with te.autocast(enabled=False, calibrating=True), torch.autocast(\n", " device_type=\"cuda\", dtype=torch.bfloat16\n", "):\n", " model.train()\n", " run_forward_pass(model, run_config, num_iters=64)\n", "\n", "# Compute scale_fwd with enabled fp8 autocast\n", - "with te.fp8_autocast(enabled=True), torch.autocast(\n", + "with te.autocast(enabled=True), torch.autocast(\n", " device_type=\"cuda\", dtype=torch.bfloat16\n", "):\n", " run_forward_pass(model, run_config, 1)\n", @@ -734,7 +734,7 @@ "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n", "\n", "\n", - "Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:" + "Transformer Engine supports maintaining FP8-only weights with the `quantized_model_init` context manager. Let's see a small example:" ] }, { @@ -778,7 +778,7 @@ "del linear_bf16\n", "\n", "# Initialize model weights in FP8 precision\n", - "with torch.no_grad(), te.fp8_model_init(enabled=True):\n", + "with torch.no_grad(), te.quantized_model_init(enabled=True):\n", " linear_fp8 = te.Linear(H, D)\n", "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", "del linear_fp8" @@ -793,11 +793,11 @@ "
\n", "\n", "
\n", - " Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", + " Figure 8: Using quantized_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", "
\n", "
\n", "\n", - "Let's run the code with `fp8_model_init`:" + "Let's run the code with `quantized_model_init`:" ] }, { @@ -862,7 +862,7 @@ "\n", "# Enable FP8 math and FP8 model weights\n", "run_config.fp8 = True\n", - "run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "run_config.quantized_model_init = True # This will result in storing only fp8 weights.\n", "run_config.fp8_model_weights_filename = (\n", " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", ")\n", @@ -885,7 +885,7 @@ "| HF (baseline) | 46.6 s | - | - | - |\n", "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n", - "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" + "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `quantized_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" ] }, { @@ -911,7 +911,7 @@ "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n", "\n", "1. Longer context lengths (with paged KV cache) \n", - "2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n", + "2. Using less memory during generation (by storing weights in FP8 precision using `quantized_model_init`)\n", "\n", "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models." ] diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index cc31afc65..9b67f178f 100755 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -34,7 +34,7 @@ def __init__(self): # FP8 precision settings self.fp8 = False self.fp8_model_weights_filename = None - self.fp8_model_init = False + self.quantized_model_init = False # Cuda graphs self.generation_cuda_graphs = False diff --git a/docs/faq.rst b/docs/faq.rst index 2f9cbd272..a9406ed45 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -15,8 +15,8 @@ Here, we take the `MultiheadAttention` module as an example. Its FP8 attention m .. code-block:: python - >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init - >>> with fp8_model_init(enabled=True): + >>> from transformer_engine.pytorch import MultiheadAttention, quantized_model_init + >>> with quantized_model_init(enabled=True): ... mha = MultiheadAttention( ... hidden_size=1024, ... num_attention_heads=16, diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py index df2dd5618..e14329d48 100644 --- a/examples/jax/collective_gemm/test_dense_grad.py +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -24,7 +24,7 @@ from transformer_engine.jax.dense import dense -from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.quantize import autocast from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOp, CollectiveOpSet, @@ -98,12 +98,12 @@ def run_dense_grad_tests(args, mesh=None): ) collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) - with mesh, fp8_autocast( + with mesh, autocast( enabled=False, - fp8_recipe=None, + recipe=None, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): - # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py index 307e4444e..ac86c551d 100644 --- a/examples/jax/collective_gemm/test_gemm.py +++ b/examples/jax/collective_gemm/test_gemm.py @@ -32,7 +32,7 @@ ) import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.quantize import autocast from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp from transformer_engine.jax.sharding import MeshResource @@ -109,9 +109,9 @@ def run_gemm_tests(args, mesh=None): else CollectiveOp.REDUCE_SCATTER ) - with mesh, fp8_autocast( + with mesh, autocast( enabled=False, - fp8_recipe=None, + recipe=None, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): print(f"Device mesh: {mesh}") diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py index 7bd6eb6a3..407cec68a 100644 --- a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -24,7 +24,7 @@ from transformer_engine.jax.layernorm_mlp import layernorm_mlp -from transformer_engine.jax.quantize import fp8_autocast +from transformer_engine.jax.quantize import autocast from transformer_engine.jax.cpp_extensions.gemm import ( CollectiveOpSet, CollectiveOp, @@ -151,12 +151,12 @@ def run_layernorm_mlp_grad_tests(args, mesh=None): collective_op_sets = (collective_op_set_1, collective_op_set_2) noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) - with mesh, fp8_autocast( + with mesh, autocast( enabled=False, - fp8_recipe=None, + recipe=None, mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), ): - # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) diff --git a/examples/jax/encoder/README.md b/examples/jax/encoder/README.md index 575f7be6e..fc6696317 100644 --- a/examples/jax/encoder/README.md +++ b/examples/jax/encoder/README.md @@ -8,7 +8,7 @@ This example uses Transformer Encoder to demonstrate the Transformer Engine usag 2. Define model: The `Net` class is a small Transformer Encoder model for sentence classification. The Transformer Engine provides `te.TransformerLayer` as encoder block and `te.DenseGeneral`. The structure of encoder block can be referred to [Scaling Up Models and Data with t5x and seqio](https://arxiv.org/abs/2203.17189) -3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `fp8_autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`. +3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`. 4. Training process: In `train_step`, combine the FP8 metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function. @@ -29,7 +29,7 @@ python test_single_gpu_encoder.py --use-fp8 3. On the model side, the logical axis of each weight tensor of the model can be named. The `te.TransformerLayer` has the default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`. The key index is `params_axes`. The `te.DenseGeneral` doesn't have the default named axis because it is generic. Also, data-parallel sharding doesn't need to divide weight tensor, so named axis is not required for this case. -4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. +4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. 5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example. @@ -136,4 +136,4 @@ numactl --cpunodebind=112 --membind=7 python test_multiprocessing_encoder.py --n numactl --cpunodebind=113 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 5 & numactl --cpunodebind=80 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 6 & numactl --cpunodebind=81 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 7 & -``` \ No newline at end of file +``` diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 5fc7efbba..7807d1fd9 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -269,9 +269,9 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, te.fp8_autocast( + ) as mesh, te.autocast( enabled=args.use_fp8, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, mesh_resource=te.MeshResource( dp_resource=DEVICE_DP_AXIS, tpsp_resource=DEVICE_TP_AXIS, @@ -287,7 +287,7 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + # Get the base axis rules and extend them with TE's rules. This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 68fb3ddd3..8ea1dcde3 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -264,11 +264,9 @@ def train_and_evaluate(args): fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu,)) - with jax.sharding.Mesh( - devices=device_mesh, axis_names=(DEVICE_DP_AXIS,) - ) as mesh, te.fp8_autocast( + with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh, te.autocast( enabled=args.use_fp8, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS), ): @@ -282,7 +280,7 @@ def train_and_evaluate(args): mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast + # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside autocast sharding_rules = te_flax.extend_logical_axis_rules(tuple()) with flax.linen.logical_axis_rules(sharding_rules): encoder = Net(num_embed) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 358fbca4b..7e708466c 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -393,9 +393,9 @@ def train_and_evaluate(args): device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, te.fp8_autocast( + ) as mesh, te.autocast( enabled=args.use_fp8, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, mesh_resource=te.MeshResource( dp_resource=DEVICE_DP_AXIS, tpsp_resource=DEVICE_TP_AXIS, @@ -413,7 +413,7 @@ def train_and_evaluate(args): # Create custom Flax logical axis rules for sharding. customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast. + # Extend the logical axis rules with TE's rules. This must be done inside autocast. sharding_rules = te_flax.extend_logical_axis_rules(customized_rules) with flax.linen.logical_axis_rules(sharding_rules): diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 320483099..79178485c 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -227,8 +227,8 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast( - enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + with te.autocast( + enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() ): encoder = Net(num_embed) # We use nn.Embed, thus inputs need to be in int diff --git a/examples/jax/mnist/README.md b/examples/jax/mnist/README.md index d67c1ac2b..f92229a25 100644 --- a/examples/jax/mnist/README.md +++ b/examples/jax/mnist/README.md @@ -6,13 +6,13 @@ This example uses MNIST training to demonstrate the Transformer Engine usage. Th 2. Define model: The `Net` class is a small CNN model for image classification. It has an option to switch between using `nn.Dense` provided by Flax and `te.DenseGeneral` provided by the Transformer Engine. This allows for easy comparison between the two libraries. -3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.fp8_autocast` context manager. If fp8_autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under fp8_autocast. If not, then fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword. +3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.autocast` context manager. If autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under autocast. If not, then autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword. 4. Training process: In `apply_model`, the main difference between normal Flax usage and this example is, with FP8 training, the FP8 metadata has to be filled into the gradient function `grad_fn`. Otherwise, the Transformer Engine doesn't know how to cast the BF16 tensor into FP8 tensor at runtime correctly. The FP8 metadata doesn't belong in model parameters (`state.params`), so we need to manually combine the metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function. 5. Evaluating process: The evaluating process is the same as the training process. Need to ensure FP8 metadata is inside var_collect and fill it into loss function. -6. Additional options: The `te.fp8_autocast` context manager has additional options +6. Additional options: The `te.autocast` context manager has additional options * FP8 Recipe: control FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options. ## Run ## diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 81bea4a32..d0aebeb53 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -193,8 +193,8 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast( - enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + with te.autocast( + enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() ): cnn = Net(args.use_te) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index d52e97d65..1fd40305c 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." ) parser.add_argument( "--no-comm-overlap", @@ -299,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(" |-- Forward pass", group=tp_group, debug=True) with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world): y = model(x) if isinstance(y, tuple): out, *_ = y diff --git a/examples/pytorch/fsdp/README.md b/examples/pytorch/fsdp/README.md index 78c4d6b85..f9a49af8d 100644 --- a/examples/pytorch/fsdp/README.md +++ b/examples/pytorch/fsdp/README.md @@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd # ... ``` -**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support +**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support (e.g.: A100), add the `--no-fp8` option to the commands shown above. diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 622228536..789389757 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -173,7 +173,7 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help="Disables the te.fp8_autocast() context.", + help="Disables the te.autocast() context.", ) parser.add_argument( "--no-defer-init", @@ -284,11 +284,11 @@ def train(opts): dtype=opts.dtype, device="cuda", ) - # fp8_autocast needs to be given the FSDP process group for amax reductions - with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + # autocast needs to be given the FSDP process group for amax reductions + with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): y = te_model(x) loss = y.sum() - # calculate gradient and take training step outside the fp8_autocast context + # calculate gradient and take training step outside the autocast context loss.backward() optim.step() optim.zero_grad(set_to_none=True) diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index ff9e2f078..f4a48bfc9 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -52,7 +52,7 @@ def train(args, model, device, train_loader, optimizer, epoch, use_fp8): for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() - with te.fp8_autocast(enabled=use_fp8): + with te.autocast(enabled=use_fp8): output = model(data) loss = F.nll_loss(output, target) loss.backward() @@ -76,7 +76,7 @@ def calibrate(model, device, test_loader, fp8): with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) - with te.fp8_autocast(enabled=fp8, calibrating=True): + with te.autocast(enabled=fp8, calibrating=True): output = model(data) @@ -88,7 +88,7 @@ def test(model, device, test_loader, use_fp8): with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) - with te.fp8_autocast(enabled=use_fp8): + with te.autocast(enabled=use_fp8): output = model(data) test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py index 9541ccfcb..15b146343 100644 --- a/tests/jax/test_distributed_dense.py +++ b/tests/jax/test_distributed_dense.py @@ -15,7 +15,7 @@ from utils import assert_allclose, pytest_parametrize_wrapper import transformer_engine.jax.cpp_extensions as tex -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.jax.dense import dense @@ -127,7 +127,7 @@ def test_distributed_gemm( contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension - with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + with mesh, autocast(enabled=False, mesh_resource=mesh_resource): # TE GEMM result te_result = _jitted_gemm( x_sharded, @@ -209,7 +209,7 @@ def test_te_distributed_dense_grad( contracting_dims = ((2,), (0,)) - with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + with mesh, autocast(enabled=False, mesh_resource=mesh_resource): # Test gradients w.r.t. all inputs te_grad_func = jax.jit( jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), diff --git a/tests/jax/test_distributed_helper.py b/tests/jax/test_distributed_helper.py index e74e9aa6f..c9647c13c 100644 --- a/tests/jax/test_distributed_helper.py +++ b/tests/jax/test_distributed_helper.py @@ -9,7 +9,7 @@ from utils import pytest_parametrize_wrapper, is_devices_enough from transformer_engine.jax.sharding import MeshResource, global_mesh_resource -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast def generate_mesh_configs(): @@ -26,10 +26,10 @@ def generate_mesh_configs(): class TestMeshResource(unittest.TestCase): - def test_fp8_autocast_with_mesh_resource(self): + def test_autocast_with_mesh_resource(self): for mesh_config in generate_mesh_configs(): device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = jax.sharding.Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + with mesh, autocast(enabled=False, mesh_resource=mesh_resource): self.assertEqual(mesh_resource, global_mesh_resource()) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 5fa08fa08..977d010af 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -15,7 +15,7 @@ from distributed_test_base import compare_ops from utils import pytest_parametrize_wrapper -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.common import recipe from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available @@ -133,7 +133,7 @@ def ref_func(x, gamma, beta): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): + with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) @@ -209,7 +209,7 @@ def ref_func(x, gamma): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): + with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index bf78ed3bb..339097e9c 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -23,7 +23,7 @@ ScalingMode, get_quantize_config_with_recipe, ) -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.sharding import ( @@ -210,9 +210,9 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast( + with autocast( enabled=quantization_recipe is not None, - fp8_recipe=quantization_recipe, + recipe=quantization_recipe, mesh_resource=MeshResource(), ): single_jitter = jax.jit( @@ -224,9 +224,9 @@ def _test_layernorm_mlp_grad( # Multi GPUs devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( + with mesh, autocast( enabled=quantization_recipe is not None, - fp8_recipe=quantization_recipe, + recipe=quantization_recipe, mesh_resource=mesh_resource, ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) @@ -381,8 +381,8 @@ def _test_layernorm_mlp( with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast( - enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=MeshResource() + with autocast( + enabled=use_fp8, recipe=quantization_recipe, mesh_resource=MeshResource() ): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, @@ -399,8 +399,8 @@ def _test_layernorm_mlp( device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=quantization_recipe, mesh_resource=mesh_resource + with mesh, autocast( + enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( layernorm_type=layernorm_type, diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index d9eaf314a..2bd4d862a 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count from distributed_test_base import compare_ops from utils import make_causal_mask, make_self_mask -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.jax.softmax import SoftmaxType, softmax DTYPES = [jnp.float16, jnp.bfloat16] @@ -102,7 +102,7 @@ def impl_test_softmax( collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): + with mesh, autocast(mesh_resource=mesh_resource): x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 710cc134b..5b814cb99 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -22,7 +22,7 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, @@ -771,7 +771,7 @@ def test_forward(self): ], ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): primitive_out = customcall_fused_dpa_jit(*customcall_args) primitive_out = self.cp_inverse_reorder_fn(primitive_out) @@ -788,7 +788,7 @@ def test_forward(self): assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) if self.coll_count_ref is not None: - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): target_hlo = ( customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() ) @@ -888,7 +888,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ) ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) reference_out, reference_dgrad = jitted_reference(*args) @@ -959,7 +959,7 @@ def check_dqkv(primitive, reference, pad, idx): ) if self.coll_count_ref is not None: - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index e9f71a32f..ca804625c 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -17,7 +17,7 @@ NVFP4BlockScaling, ) from transformer_engine.common.recipe import Format as FP8Format -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.jax.quantize import ( get_quantize_config, is_scaling_mode_supported, @@ -97,84 +97,78 @@ def _compare_nvfp4_scaling(self, test): ) @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_fp8_autocast_delayed_scaling(self): + def test_autocast_delayed_scaling(self): self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): + with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()): self._check_default_state() self._check_default_state() ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(ds) self._check_default_state() @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_fp8_autocast_current_scaling(self): + def test_autocast_current_scaling(self): self._check_default_state() - with fp8_autocast( - enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource() - ): + with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()): self._check_default_state() self._check_default_state() cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) - def test_fp8_autocast_mxfp8_block_scaling(self): + def test_autocast_mxfp8_block_scaling(self): self._check_default_state() - with fp8_autocast( - enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource() - ): + with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()): self._check_default_state() self._check_default_state() bs = MXFP8BlockScaling() - with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) - def test_fp8_autocast_nvfp4_block_scaling(self): + def test_autocast_nvfp4_block_scaling(self): self._check_default_state() - with fp8_autocast( - enabled=False, fp8_recipe=NVFP4BlockScaling(), mesh_resource=MeshResource() - ): + with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()): self._check_default_state() self._check_default_state() bs = NVFP4BlockScaling() - with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_nvfp4_scaling(bs) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 6f672ade7..d1b2535c4 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -28,7 +28,7 @@ is_fp8_available, update_collections, TensorSource, - fp8_autocast, + autocast, ) from transformer_engine.jax.sharding import MeshResource @@ -507,14 +507,14 @@ def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" # Ensure FP8 disabled. # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=False, mesh_resource=MeshResource()): + with autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" # Ensure FP8 disabled. # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=False, mesh_resource=MeshResource()): + with autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -522,7 +522,7 @@ def test_backward(self, data_shape, dtype, attrs): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -530,7 +530,7 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index d490c235b..1edffaf48 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -8,16 +8,15 @@ from contextlib import nullcontext import torch import torch.distributed as dist -from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn -from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Tensor, +from transformer_engine.pytorch import ( + autocast, + DotProductAttention, Float8Quantizer, Float8CurrentScalingQuantizer, ) @@ -306,7 +305,7 @@ def run_dpa_with_cp( ############ run without CP ############ logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() with fp8_context: @@ -396,7 +395,7 @@ def run_dpa_with_cp( if dtype == "fp8": core_attn.fp8_initialized = False core_attn.fp8_meta_tensors_initialized = False - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index e3a4de73b..7dc6caeb8 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -10,13 +10,22 @@ import pytest import torch +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype from transformer_engine.common import recipe -from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init -from transformer_engine.pytorch.attention.dot_product_attention import ( +from transformer_engine.pytorch import ( + TransformerLayer, + autocast, + quantized_model_init, DotProductAttention, + MultiheadAttention, + get_device_compute_capability, + Quantizer, + is_fp8_available, + is_bf16_available, +) +from transformer_engine.pytorch.attention.dot_product_attention import ( _attention_backends, ) -from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, check_set_window_size, @@ -29,18 +38,14 @@ fused_attn_fwd, ) from transformer_engine.pytorch.distributed import CudaRNGStatesTracker -import transformer_engine.pytorch.fp8 as fp8 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( - get_device_compute_capability, init_method_normal, scaled_init_method_normal, - is_bf16_compatible, ) from transformer_engine.pytorch.utils import get_cudnn_version import transformer_engine_torch as tex from transformer_engine.pytorch.tensor.quantized_tensor import ( - Quantizer, prepare_for_saving, restore_from_saved, ) @@ -56,7 +61,7 @@ ) # Check if hardware supports FP8 -fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) # Reset RNG seed and states seed = 1234 @@ -67,12 +72,12 @@ @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield - fp8.FP8GlobalStateManager.reset() + FP8GlobalStateManager.reset() # Define F16 data types to test param_types = [torch.float16] -if is_bf16_compatible(): +if is_bf16_available(): param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] @@ -1592,7 +1597,7 @@ def get_model(dtype, config): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): + with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -1609,7 +1614,7 @@ def get_model(dtype, config): block = get_model(dtype, config) for i in range(steps // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_enabled, recipe=fp8_recipe): output = block(hidden_states, None) loss = output.sum() loss.backward() @@ -1644,7 +1649,7 @@ def get_model(dtype, config): assert not param_grads, "Oops!" for i in range((steps + 1) // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_enabled, recipe=fp8_recipe): output = block(hidden_states, None) loss = output.sum() loss.backward() @@ -1820,7 +1825,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): + with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: PE = RotaryPositionEmbedding(dim=config.head_dim_qk) @@ -1892,7 +1897,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, attn_mask_type=config.attn_mask_type, @@ -2110,7 +2115,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: return _DUMMY_CUDA_RNG_STATE_TRACKER qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - with fp8_model_init(enabled=fp8_dpa): + with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, config.head_dim_qk, @@ -2202,7 +2207,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") - with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_dpa, recipe=fp8_recipe): out = dpa( inp[0], inp[1], @@ -2343,7 +2348,7 @@ def _run_custom_mha_fp8(dtype, config, backend): ) mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with autocast(enabled=True, recipe=fp8_recipe): out = mha(inp, cu_seqlens, config.max_seqlen_q) out.backward(out_grad) @@ -2541,7 +2546,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) proj_dgrad = ctx.dO_quantizer(grad_output) - fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0f00b8b0e..2c7f9d857 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -10,7 +10,7 @@ import pytest import torch -from transformer_engine.pytorch.utils import ( +from transformer_engine.pytorch import ( get_device_compute_capability, get_cudnn_version, ) diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py index 00200c62d..0dd5ba601 100644 --- a/tests/pytorch/attention/test_cp_utils.py +++ b/tests/pytorch/attention/test_cp_utils.py @@ -5,7 +5,6 @@ """Unit tests for context parallel utils.""" import torch import unittest -from typing import Tuple from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_batch_on_this_cp_rank, pad_thd_sequences_for_cp, diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index 4dc3af411..864276a67 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -14,20 +14,22 @@ import torch from torch.distributions import Exponential -from transformer_engine.pytorch import make_graphed_callables -from transformer_engine.common import recipe -from transformer_engine.pytorch import fp8_autocast, fp8_model_init -from transformer_engine.pytorch.transformer import ( +from transformer_engine.pytorch import ( + make_graphed_callables, + autocast, + quantized_model_init, TransformerLayer, + DotProductAttention, + InferenceParams, + is_bf16_available, ) -from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams +from transformer_engine.common import recipe from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, ) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, - is_bf16_compatible, ) _current_file = pathlib.Path(__file__).resolve() @@ -42,7 +44,7 @@ reset_rng_states() param_types = [torch.float16] -if is_bf16_compatible(): +if is_bf16_available(): param_types.append(torch.bfloat16) model_configs_infer = { @@ -238,7 +240,7 @@ def get_model( if module == "TransformerLayer": hidden_size = config.head_dim_qk * config.num_heads - with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): + with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe): model = [ TransformerLayer( hidden_size=hidden_size, @@ -261,7 +263,7 @@ def get_model( for layer_number in range(1, num_layers + 1) ] if module == "DotProductAttention": - with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): + with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe): model = [ DotProductAttention( kv_channels=config.head_dim_qk, @@ -559,9 +561,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g model[i], sample_args, num_warmup_iters=10, - fp8_enabled=is_fp8, + enabled=is_fp8, sample_kwargs=sample_kwargs, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, ) for i in range(num_layers) ] @@ -654,7 +656,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g if inference_params.is_paged: inference_params.cache_manager.print_cache() incremental_output = incremental_inputs - with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=is_fp8, recipe=fp8_recipe): for m in model: incremental_output = m( *incremental_output, diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 716c16056..fee2189fa 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -16,7 +16,7 @@ import transformer_engine_torch as tex import nvdlfw_inspect.api as debug_api from transformer_engine.debug import set_weight_tensor_tp_group_reduce -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import is_fp8_available from test_numerics import ( _emulate_linear, @@ -45,7 +45,7 @@ all_boolean = [True, False] TEST_NR = 0 -fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +fp8_available = is_fp8_available() def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): @@ -117,7 +117,7 @@ def backward(ctx, grad_output): def _run_forward_backward(x, model, parallel_mode=None, group=None): - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE): y = model(x) y.requires_grad_(True) @@ -413,13 +413,13 @@ def test_log_expert_parallel(**kwargs): ) # data parallel model = _init_model(weight, parallel_mode=None, name="linear1") model1 = _init_model(weight, parallel_mode=None, name="linear2") - with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE): y1 = model(x) y2 = model1(x) y = y1 + y2 y.sum().backward() debug_api.step() - with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE): y = model(x) if WORLD_RANK != 0: y = y + model1(x) @@ -532,7 +532,7 @@ def test_per_tensor_scaling( LOSS_MULTIPLIER = 100 - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE): y = model(x) model.zero_grad() if parallel_mode == "column": diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index d28db1647..fbf619d48 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -3,7 +3,7 @@ # See LICENSE for license information. import torch -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch import Float8Tensor, Float8Quantizer import nvdlfw_inspect.api as debug_api diff --git a/tests/pytorch/debug/test_config.py b/tests/pytorch/debug/test_config.py index 71715a686..9b6bcd1cd 100644 --- a/tests/pytorch/debug/test_config.py +++ b/tests/pytorch/debug/test_config.py @@ -1,7 +1,7 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -import pathlib, os +import pathlib from nvdlfw_inspect.config_manager import ConfigManager diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index dcc9861c8..e9d074821 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -8,18 +8,22 @@ import torch import tempfile from transformer_engine.common import recipe -from transformer_engine.pytorch.fp8 import RecipeState import pytest import contextlib import os -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import ( + is_fp8_available, + is_mxfp8_available, + is_fp8_block_scaling_available, +) +from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.debug.pytorch.debug_state import TEDebugState -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( + return_reason=True ) LOG_QUANTIZED_CONFIG_BASE = """ @@ -128,7 +132,7 @@ def test_sanity(feature_dirs): inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda() for _ in range(10): - with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): + with te.autocast(recipe=recipe.DelayedScaling()): output = model(inp) loss = output.sum() loss.backward() @@ -232,7 +236,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): for i in range(20): x = torch.randn(4, 128, 128).cuda() - with te.fp8_autocast(enabled=True): + with te.autocast(enabled=True): y = model(x) y.sum().backward() debug_api.step() diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py index 749fa16bc..2ad2c8fb8 100644 --- a/tests/pytorch/debug/test_numerics.py +++ b/tests/pytorch/debug/test_numerics.py @@ -17,19 +17,19 @@ import transformer_engine.pytorch as tepytorch import transformer_engine_torch as tex from transformer_engine.common.recipe import DelayedScaling, Format -from transformer_engine.pytorch.fp8 import _default_sf_compute -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch.quantization import _default_sf_compute +from transformer_engine.pytorch import ( Float8Quantizer, Float8CurrentScalingQuantizer, + is_fp8_available, ) from transformer_engine.pytorch.module.base import ( _2X_ACC_DGRAD, _2X_ACC_FPROP, _2X_ACC_WGRAD, ) -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) all_boolean = [True, False] FP8_FORMAT = Format.HYBRID @@ -250,7 +250,7 @@ def _init_model(weight): def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True): - with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE): + with tepytorch.autocast(enabled=fp8, recipe=FP8_RECIPE): y = model(x, is_first_microbatch=is_first_microbatch) (y.sum() * loss_scale).backward() debug_api.step() @@ -547,7 +547,7 @@ def run_per_tensor_scaling( LOSS_MULTIPLIER = 100 - with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + with tepytorch.autocast(enabled=True, recipe=FP8_RECIPE): y = model(x, is_first_microbatch=True) model.zero_grad() y.retain_grad() diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index e4ce35be6..97be3003d 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -7,11 +7,10 @@ import nvdlfw_inspect.api as debug_api import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from test_numerics import create_config_file -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) B, S, H, D = 64, 64, 64, 64 @@ -68,7 +67,7 @@ def _get_model(model_key): def _run_forward_backward(model, fp8): for _ in range(3): inp = torch.randn((S, B, H)).cuda() - with te.fp8_autocast(enabled=fp8): + with te.autocast(enabled=fp8): out = model(inp) out.sum().backward() debug_api.step() diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py index 875905c78..976991633 100644 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -21,13 +21,13 @@ Recipe, ) import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( + QuantizedTensor, Float8Tensor, - Float8CurrentScalingQuantizer, + Float8BlockwiseQTensor, ) +from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 from transformer_engine.pytorch.tensor.utils import replace_raw_data -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor def _get_raw_data(quantized_tensor): @@ -439,7 +439,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): } # Create model with FP8 weights - with te.fp8.fp8_model_init( + with te.quantized_model_init( enabled=quantization is not None, recipe=quantization_recipe(quantization), preserve_high_precision_init_val=True, @@ -475,17 +475,17 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): # Choose based on rank to make sure the inputs of different ranks are different. x = inputs[rank] - with te.fp8.fp8_autocast( + with te.autocast( enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, + recipe=quantization_recipe(quantization), + amax_reduction_group=mock_group, ): y_fp8 = model_fp8(x) - with te.fp8_autocast( + with te.autocast( enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, + recipe=quantization_recipe(quantization), + amax_reduction_group=mock_group, ): y = model(x) @@ -573,7 +573,7 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} # Create model with FP8 weights - with te.fp8.fp8_model_init( + with te.quantized_model_init( enabled=quantization is not None, recipe=quantization_recipe(quantization), preserve_high_precision_init_val=True, @@ -615,17 +615,17 @@ def _test_cast_master_weights_to_fp8(quantization, dp_group): # Choose based on rank to make sure the inputs of different ranks are different. x = inputs[rank] - with te.fp8.fp8_autocast( + with te.autocast( enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, + recipe=quantization_recipe(quantization), + amax_reduction_group=mock_group, ): y_fp8 = model_fp8(x) - with te.fp8_autocast( + with te.autocast( enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, + recipe=quantization_recipe(quantization), + amax_reduction_group=mock_group, ): y = model(x) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index e32f64cf1..8026fc0a3 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -110,9 +110,9 @@ def _train(args): build_model_context = nullcontext build_model_context_args = {} - from transformer_engine.pytorch import fp8_model_init + from transformer_engine.pytorch import quantized_model_init - build_model_context = fp8_model_init + build_model_context = quantized_model_init build_model_context_args["enabled"] = True # Build the model with the specified context diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 6d9e2f152..df0e4a216 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -18,9 +18,12 @@ from torch.distributed.elastic.multiprocessing.errors import record import transformer_engine.pytorch as te +from transformer_engine.pytorch import ( + Float8Tensor, + Float8Quantizer, + MXFP8Quantizer, +) import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.module.base import ( fill_userbuffers_buffer_for_all_gather, get_cublas_workspace_size_bytes, @@ -171,12 +174,12 @@ def _parse_args(argv=None, namespace=None): opts.p2p = True if opts.atomic: - if not te.fp8.check_fp8_support(): + if not te.is_fp8_available(): assert opts.quantization == "none", "Atomic GEMM is only supported in FP8." opts.quantization = "fp8" if opts.fp8_output: - assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute." + assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute." return opts diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 2a6e55b2c..b2bd6dd77 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." ) parser.add_argument( "--quantization", @@ -438,7 +438,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, ) - with te.fp8_model_init(enabled=opts.fp8_init): + with te.quantized_model_init(enabled=opts.fp8_init): test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs) dist_print("Initialized test model...", debug=True) if WORLD_RANK == 0: @@ -450,7 +450,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ref_args, ref_kwargs, _ = _get_layer_args( opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True ) - with te.fp8_model_init(enabled=opts.fp8_init): + with te.quantized_model_init(enabled=opts.fp8_init): ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs) dist_print("Initialized reference model...", debug=True) for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): @@ -473,7 +473,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): layer_contexts = [ ( - partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world) + partial( + te.autocast, enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world + ) if opts.num_layers_at_start_in_bf16 <= i and i < (opts.num_layers - opts.num_layers_at_end_in_bf16) else nullcontext diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index a4aa74bd8..63ecb548b 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -26,8 +26,7 @@ Recipe, QParams, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE from transformer_engine.pytorch.distributed import gather_along_first_dim from run_layer_with_overlap import _compare_tensors @@ -75,7 +74,7 @@ def quantization_recipe() -> Recipe: return Float8BlockScaling() if QUANTIZATION == "nvfp4": return nvfp4_vanilla() - return te.fp8.get_default_fp8_recipe() + return te.quantization.get_default_fp8_recipe() def main(argv=None, namespace=None): @@ -316,15 +315,15 @@ def _apply_models( _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True input_single_node.requires_grad_() input_distributed.requires_grad_() - with te.fp8_autocast( + with te.autocast( enabled=QUANTIZATION is not None, - fp8_recipe=quantization_recipe(), + recipe=quantization_recipe(), ): output_single_node = model_single_node(input_single_node, **kwargs) - with te.fp8_autocast( + with te.autocast( enabled=QUANTIZATION is not None, - fp8_recipe=quantization_recipe(), - fp8_group=NCCL_WORLD, + recipe=quantization_recipe(), + amax_reduction_group=NCCL_WORLD, ): output_distributed = model_distributed(input_distributed, **kwargs) return output_single_node, output_distributed diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 40be8e1f0..ccbc3259b 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -9,21 +9,18 @@ import os import sys from functools import wraps -import math import transformer_engine.pytorch as te import torch from torch import nn import torch.distributed as dist -import transformer_engine_torch as tex from transformer_engine.common.recipe import ( NVFP4BlockScaling, - Format, Recipe, QParams, CustomRecipe, ) -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE from transformer_engine.pytorch.experimental import quantization_nvfp4 from transformer_engine.pytorch.experimental import utils @@ -506,7 +503,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): ) # run the recipe under test - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.autocast(enabled=True, recipe=recipe): y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear( x, w, @@ -524,7 +521,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): # run the reference reference_recipe = quantization_reference_recipe() - with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe): + with te.autocast(enabled=True, recipe=reference_recipe): y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear( x, w, @@ -700,7 +697,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs ) # run the recipe under test - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.autocast(enabled=True, recipe=recipe): y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear( x, w, @@ -717,7 +714,7 @@ def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs # run the reference reference_recipe = quantization_reference_recipe() - with te.fp8_autocast(enabled=True, fp8_recipe=reference_recipe): + with te.autocast(enabled=True, recipe=reference_recipe): y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = ( TestDistributedLayerNormLinearBase.run_layernorm_linear( x, diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index fd802c910..5bf46b8d5 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -8,15 +8,15 @@ import pytest import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available if torch.cuda.device_count() < 2: pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( + return_reason=True ) TEST_ROOT = Path(__file__).parent.resolve() diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 74d1dc69c..ddb31c30f 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -9,13 +9,12 @@ import torch import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager if torch.cuda.device_count() < 2: pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) RNG_SEED: int = 42 SEQ_LENGTH: int = 1024 diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index af0f0e931..5844d8109 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -20,16 +20,15 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( + QuantizedTensor, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, + is_bf16_available, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Import utility functions @@ -39,9 +38,9 @@ # Check what quantization schemes are supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_mxfp8_available(return_reason=True) quantization_list: list[Optional[str]] = [None] if fp8_available: quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) @@ -427,7 +426,7 @@ def _test_basic_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -440,7 +439,7 @@ def _test_basic_linear( with torch.no_grad(): op.weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -593,7 +592,7 @@ def _test_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -612,7 +611,7 @@ def _test_linear( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -759,7 +758,7 @@ def _test_mlp( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.GELU(), te_ops.Linear( @@ -795,7 +794,7 @@ def _test_mlp( # Warmup steps for _ in range(3): - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) x_test.grad = None @@ -806,7 +805,7 @@ def _test_mlp( model[3].bias.grad = None # Forward and backward step - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -944,7 +943,7 @@ def ref_amax_and_scale( amax_history_len=amax_history_len, amax_compute_algo=amax_compute_algo, ) - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -1004,7 +1003,7 @@ def run_parallel_tests() -> None: if rank == 0: print(f"Running _test_linear with {config=}") quantization, tensor_parallel_mode, sequence_parallel = config - dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, @@ -1018,7 +1017,7 @@ def run_parallel_tests() -> None: if rank == 0: print(f"Running _test_mlp with {config=}") quantization, sequence_parallel = config - dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 _test_mlp( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index d6ddfe27c..24112cc9f 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -16,23 +16,23 @@ import pytest import torch +from typing import Optional, Iterable + import transformer_engine import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, + QuantizedTensor, + Float8Tensor, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.utils import is_bf16_compatible # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -40,8 +40,8 @@ from utils import dtype_tols, make_recipe, str_to_dtype # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) quantization_list: list[Optional[str]] = [None] if fp8_available: quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) @@ -301,7 +301,7 @@ def _test_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): + with te.quantized_model_init(enabled=quantized_compute, recipe=recipe): ops = [] linear_op = None bias_op = None @@ -351,7 +351,7 @@ def _test_linear( bias_op.bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index d09c530cb..97a69e779 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -8,7 +8,7 @@ import pytest import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te """ Distributed numerics tests @@ -26,12 +26,12 @@ if torch.cuda.device_count() < 2: pytest.skip("Distributed training needs at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True ) -nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py index 890a24804..fd6ef65e0 100644 --- a/tests/pytorch/distributed/test_numerics_exact.py +++ b/tests/pytorch/distributed/test_numerics_exact.py @@ -8,7 +8,7 @@ import pytest import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te """ Distributed numerics tests @@ -23,12 +23,12 @@ if torch.cuda.device_count() < 2: pytest.skip("Distributed training needs at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True ) -nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) TEST_ROOT = Path(__file__).parent.resolve() NUM_PROCS: int = min(4, torch.cuda.device_count()) diff --git a/tests/pytorch/distributed/test_sanity.py b/tests/pytorch/distributed/test_sanity.py index 39494a92b..fbbbe2997 100644 --- a/tests/pytorch/distributed/test_sanity.py +++ b/tests/pytorch/distributed/test_sanity.py @@ -7,8 +7,7 @@ import pytest import torch import transformer_engine -from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention -from transformer_engine.pytorch import TransformerLayer, Linear +from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index f5c186a3b..8fe4e8bc7 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -6,13 +6,12 @@ import pytest import subprocess from pathlib import Path -from transformer_engine.pytorch import torch_version -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te import torch -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) NUM_PROCS: int = torch.cuda.device_count() @@ -34,7 +33,7 @@ def _run_test(fp_init, sharding_dims): @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") -@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") +@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("fp8_init", (False, True)) def test_distributed(fp8_init, sharding_dims): diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 42837fb40..77cfaaffe 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -4,16 +4,15 @@ import pytest import torch -import transformer_engine as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.experimental import utils -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) def check_nvfp4_gemm_versus_reference( diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index 1d1467640..44f222b9d 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -4,15 +4,13 @@ import pytest import torch -import transformer_engine as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.distributed import fp8_autocast +import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch.experimental import quantization_nvfp4 from transformer_engine.pytorch.experimental import utils -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) class GetRecipes: @@ -152,16 +150,16 @@ def check_nvfp4_module_versus_reference( # Create native module print("\nCreate native module") - if module_class == te.pytorch.Linear: - native_module = te.pytorch.Linear( + if module_class == te.Linear: + native_module = te.Linear( in_features=in_features, out_features=out_features, bias=bias, device=device, params_dtype=x_dtype, ) - elif module_class == te.pytorch.LayerNormLinear: - native_module = te.pytorch.LayerNormLinear( + elif module_class == te.LayerNormLinear: + native_module = te.LayerNormLinear( in_features=in_features, out_features=out_features, bias=bias, @@ -176,16 +174,16 @@ def check_nvfp4_module_versus_reference( # Create reference module print("Create reference module") - if module_class == te.pytorch.Linear: - ref_module = te.pytorch.Linear( + if module_class == te.Linear: + ref_module = te.Linear( in_features=in_features, out_features=out_features, bias=bias, device=device, params_dtype=x_dtype, ) - elif module_class == te.pytorch.LayerNormLinear: - ref_module = te.pytorch.LayerNormLinear( + elif module_class == te.LayerNormLinear: + ref_module = te.LayerNormLinear( in_features=in_features, out_features=out_features, bias=bias, @@ -232,13 +230,13 @@ def check_nvfp4_module_versus_reference( grad_output = grad_output_val.clone().detach() # Native forward/backward - with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + with te.autocast(enabled=True, recipe=nvfp4_recipe): # enable weight cache by giving is_first_microbatch y_native = native_module(x_native, is_first_microbatch=(step == 0)) y_native.backward(grad_output) # Reference forward/backward - with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe): + with te.autocast(enabled=True, recipe=nvfp4_ref_recipe): y_ref = ref_module(x_ref) y_ref.backward(grad_output) @@ -361,7 +359,7 @@ def test_nvfp4_linear_versus_reference( pytest.skip("RHT is only supported for bfloat16 input") check_nvfp4_module_versus_reference( - module_class=te.pytorch.Linear, + module_class=te.Linear, in_features=in_features, out_features=out_features, bias=bias, @@ -394,7 +392,7 @@ def check_nvfp4_layernorm_linear_versus_reference( reset_rng_states() # Native module - native_module = te.pytorch.LayerNormLinear( + native_module = te.LayerNormLinear( in_features=in_features, out_features=out_features, bias=bias, @@ -406,7 +404,7 @@ def check_nvfp4_layernorm_linear_versus_reference( # Reference module reset_rng_states() - ref_module = te.pytorch.LayerNormLinear( + ref_module = te.LayerNormLinear( in_features=in_features, out_features=out_features, bias=bias, @@ -456,12 +454,12 @@ def check_nvfp4_layernorm_linear_versus_reference( grad_output = grad_output_val.clone().detach() # Native forward/backward - with fp8_autocast(enabled=True, fp8_recipe=nvfp4_recipe): + with te.autocast(enabled=True, recipe=nvfp4_recipe): y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0)) y_native.backward(grad_output) # Reference forward/backward - with fp8_autocast(enabled=True, fp8_recipe=nvfp4_ref_recipe): + with te.autocast(enabled=True, recipe=nvfp4_ref_recipe): y_ref, ln_out_ref = ref_module(x_ref) y_ref.backward(grad_output) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index cdcb2df1d..8c2444557 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -4,20 +4,16 @@ import pytest import torch -import transformer_engine as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.nvfp4_tensor import ( - NVFP4Quantizer, -) -from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.experimental import utils -from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) def unpack_fp4(x: torch.Tensor) -> torch.Tensor: diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 494fa63c0..6f2f846a3 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -9,25 +9,18 @@ # Due to the structure of NVFP4Quantizer, we need to test the RHT functionality # together with the quantization functionality. -from typing import Tuple -import math - -import transformer_engine as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.nvfp4_tensor import ( - NVFP4Quantizer, -) from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.experimental import utils -from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp4_te_dtype import pytest import torch -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) def unpack_fp4(x: torch.Tensor) -> torch.Tensor: diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py index 46077eb20..0842de9ea 100755 --- a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py @@ -4,10 +4,10 @@ import pytest import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer +import transformer_engine.pytorch as te +from transformer_engine.pytorch import NVFP4Quantizer -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_nvfp4_available() +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) seed = 12345 torch.manual_seed(seed) diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py index 16e7feb1b..99a3af0d6 100644 --- a/tests/pytorch/test_checkpoint.py +++ b/tests/pytorch/test_checkpoint.py @@ -12,13 +12,15 @@ import pytest import torch +from typing import Optional + import transformer_engine.pytorch as te from utils import make_recipe # Check supported quantization schemes -fp8_available, reason_for_no_fp8 = te.fp8.FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = te.fp8.FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) # Test cases for loading checkpoint files @@ -65,16 +67,16 @@ def _make_module(name: str) -> torch.nn.Module: if name == "ops_linear": return te.ops.Linear(1, 1) if name == "linear.fp8": - with te.fp8_model_init(recipe=make_recipe("fp8")): + with te.quantized_model_init(recipe=make_recipe("fp8")): return te.Linear(16, 16) if name == "ops_linear.fp8": - with te.fp8_model_init(recipe=make_recipe("fp8")): + with te.quantized_model_init(recipe=make_recipe("fp8")): return te.ops.Linear(16, 16) if name == "linear.mxfp8": - with te.fp8_model_init(recipe=make_recipe("mxfp8")): + with te.quantized_model_init(recipe=make_recipe("mxfp8")): return te.Linear(32, 32) if name == "ops_linear.mxfp8": - with te.fp8_model_init(recipe=make_recipe("mxfp8")): + with te.quantized_model_init(recipe=make_recipe("mxfp8")): return te.ops.Linear(32, 32) raise ValueError(f"Unrecognized module name ({name})") diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 0e01f0b04..64da83a21 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -12,14 +12,13 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported from utils import ModelConfig, get_available_attention_backends # Check supported quantization schemes -fp8_available, _ = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +fp8_available = te.is_fp8_available() +mxfp8_available = te.is_mxfp8_available() quantization_recipes: Optional[recipe.Recipe] = [None] if fp8_available: @@ -79,9 +78,9 @@ def _warmup_model( """Perform forward and backward pass""" tensor = _make_input() for module in modules: - with te.fp8_autocast( + with te.autocast( enabled=quantization_recipe is not None, - fp8_recipe=quantization_recipe, + recipe=quantization_recipe, ): tensor = module(tensor) tensor.sum().backward() @@ -159,8 +158,8 @@ def _measure_cached_memory( tensor = inp memory_before_forward = torch.cuda.memory_allocated() / (1024**2) for module in modules: - with te.fp8_autocast( - enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe + with te.autocast( + enabled=quantization_recipe is not None, recipe=quantization_recipe ), offload_context: tensor = module(tensor) tensor = sync_function(tensor) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index be7a65deb..fa8754d60 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -13,20 +13,23 @@ Linear, MultiheadAttention, TransformerLayer, - fp8_autocast, - fp8_model_init, + autocast, + quantized_model_init, make_graphed_callables, + is_fp8_available, + is_fp8_block_scaling_available, + is_mxfp8_available, + is_bf16_available, ) -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe from utils import ModelConfig, reset_rng_states # Check if FP8 is supported. -fp8_available, _ = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() -mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +fp8_available = is_fp8_available() +fp8_block_scaling_available = is_fp8_block_scaling_available() +mxfp8_available = is_mxfp8_available() # Reset RNG states. reset_rng_states() @@ -93,7 +96,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher dtypes.append(torch.bfloat16) @@ -201,7 +204,7 @@ def _test_cuda_graphs( fp8_weight_caching = False # Create modules. - with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe): + with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe): if module == "transformer": modules = [ TransformerLayer( @@ -281,9 +284,9 @@ def _test_cuda_graphs( model, (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, - fp8_enabled=fp8, - fp8_weight_caching=fp8_weight_caching, - fp8_recipe=fp8_recipe, + enabled=fp8, + cache_quantized_params=fp8_weight_caching, + recipe=fp8_recipe, ) elif graph_mode == "individual": # Graph individual modules. @@ -292,9 +295,9 @@ def _test_cuda_graphs( module, (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, - fp8_enabled=fp8, - fp8_weight_caching=fp8_weight_caching, - fp8_recipe=fp8_recipe, + enabled=fp8, + cache_quantized_params=fp8_weight_caching, + recipe=fp8_recipe, ) for module in modules ] @@ -311,7 +314,7 @@ def _test_cuda_graphs( for grad_accumulation_step in range(2): input_ = generate_data(model_config, dtype) grad_output = generate_data(model_config, dtype, requires_grad=False) - with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8, recipe=fp8_recipe): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 @@ -455,7 +458,7 @@ def _test_cuda_graphs_with_dot_product_attention( model, generate_data_for_dot_product_attention(model_config, dtype, warmup=True), num_warmup_iters=10, - fp8_enabled=False, + enabled=False, ) # Forward and backward passes. diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index cb840f197..516354a34 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -5,23 +5,23 @@ import pytest import torch -import transformer_engine as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.common import recipe -from transformer_engine.pytorch.fp8 import check_fp8_support, fp8_autocast -from transformer_engine.pytorch import Linear -import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.module.layernorm_linear import LayerNormLinear -from transformer_engine.pytorch.module.layernorm_mlp import LayerNormMLP -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( + autocast, + Linear, + LayerNormLinear, + LayerNormMLP, + GroupedLinear, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.module.grouped_linear import GroupedLinear +import transformer_engine.pytorch.ops as te_ops @pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"]) def test_custom_recipe_sanity(module_type): - available, reason = check_fp8_support() + available, reason = te.is_fp8_available(return_reason=True) if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") @@ -57,7 +57,7 @@ def quantizer_factory(role): custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) # Execute with custom recipe - with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + with autocast(enabled=True, recipe=custom_recipe): out = model(inp) loss = out.float().sum() loss.backward() @@ -67,7 +67,7 @@ def quantizer_factory(role): def test_custom_recipe_grouped_linear_sanity(): - available, reason = check_fp8_support() + available, reason = te.is_fp8_available(return_reason=True) if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") @@ -93,7 +93,7 @@ def quantizer_factory(role): custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) - with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + with autocast(enabled=True, recipe=custom_recipe): out = model(inp, m_splits) loss = out.float().sum() loss.backward() @@ -102,7 +102,7 @@ def quantizer_factory(role): def test_custom_recipe_matches_current_scaling(): - available, reason = check_fp8_support() + available, reason = te.is_fp8_available(return_reason=True) if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") @@ -124,7 +124,7 @@ def test_custom_recipe_matches_current_scaling(): # Reference: use Float8CurrentScaling recipe ref_recipe = recipe.Float8CurrentScaling() - with fp8_autocast(enabled=True, fp8_recipe=ref_recipe): + with autocast(enabled=True, recipe=ref_recipe): out_ref = model_ref(inp_ref) # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd) ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -155,7 +155,7 @@ def quantizer_factory(role): custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) - with fp8_autocast(enabled=True, fp8_recipe=custom_recipe): + with autocast(enabled=True, recipe=custom_recipe): out_custom = model_custom(inp_custom) # Assert dtypes for custom quantizers match reference mapping cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] @@ -189,7 +189,7 @@ def quantizer_factory(role): def test_custom_recipe_ops_linear_2_1_layout(): - available, reason = check_fp8_support() + available, reason = te.is_fp8_available(return_reason=True) if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") @@ -212,7 +212,7 @@ def quantizer_factory(role): custom = recipe.CustomRecipe(qfactory=quantizer_factory) - with fp8_autocast(enabled=True, fp8_recipe=custom): + with autocast(enabled=True, recipe=custom): out = op(inp) loss = out.float().sum() loss.backward() @@ -221,7 +221,7 @@ def quantizer_factory(role): def test_custom_recipe_factory_invocation_counts_and_cycling(): - available, reason = check_fp8_support() + available, reason = te.is_fp8_available(return_reason=True) if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") @@ -256,7 +256,7 @@ def quantizer_factory(role): # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), # and backward to build 2 quantizers (cycled from 1 factory). - with fp8_autocast(enabled=True, fp8_recipe=custom): + with autocast(enabled=True, recipe=custom): out = op(inp) loss = out.float().sum() loss.backward() @@ -270,7 +270,7 @@ def quantizer_factory(role): def test_factories_return_distinct_instances_and_buffers(): - available, reason = check_fp8_support() + available, reason = te.is_fp8_available(return_reason=True) if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py index 7d6d52362..4ce522495 100644 --- a/tests/pytorch/test_deferred_init.py +++ b/tests/pytorch/test_deferred_init.py @@ -4,7 +4,6 @@ import pytest import torch -import torch.distributed as dist import transformer_engine.pytorch as te diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index bdc73519b..9ae8a6069 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -4,22 +4,20 @@ import pytest import torch -import transformer_engine as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.utils import get_device_compute_capability -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( +from transformer_engine.pytorch import ( Float8BlockQuantizer, - Float8BlockwiseQTensor, + get_device_compute_capability, ) from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() + supported = te.is_fp8_block_scaling_available() emulated = get_device_compute_capability() >= (10, 0) return supported and not emulated diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 51e0d1ec9..153f0b7e0 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -8,15 +8,12 @@ import pathlib import pytest import torch -import transformer_engine as te -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te from transformer_engine.common.recipe import Float8BlockScaling -from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( +from transformer_engine.pytorch import ( Float8BlockQuantizer, - Float8BlockwiseQTensor, + get_device_compute_capability, ) from references.blockwise_quantizer_reference import ( BlockwiseQuantizerReference, @@ -32,7 +29,7 @@ tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() +recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True) recipe_emulated = get_device_compute_capability() >= (10, 0) diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 82bd61a01..e4d6ce365 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -8,11 +8,9 @@ import pytest import transformer_engine.pytorch as te -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8CurrentScaling -from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype +from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory @@ -23,7 +21,7 @@ # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) class GetRecipes: @@ -394,7 +392,7 @@ def compare_recipe( # recipe1 using_fp8_recipe = recipe1() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + with autocast(enabled=True, recipe=recipe1()): y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) else: y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) @@ -402,7 +400,7 @@ def compare_recipe( # recipe2 using_fp8_recipe = recipe2() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + with autocast(enabled=True, recipe=recipe2()): y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) else: y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) @@ -617,7 +615,7 @@ def compare_recipe( # recipe1 using_fp8_recipe = recipe1() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + with autocast(enabled=True, recipe=recipe1()): y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( x, w, @@ -639,7 +637,7 @@ def compare_recipe( # recipe2 using_fp8_recipe = recipe2() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + with autocast(enabled=True, recipe=recipe2()): y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( x, w, diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 39062b442..c59f8d8c6 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -11,12 +11,11 @@ import torch import transformer_engine.common.recipe -import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( +from transformer_engine.pytorch import ( Float8BlockQuantizer, Float8BlockwiseQTensor, + get_device_compute_capability, ) -from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex # PyTorch tensor dtypes diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index a5c97a950..b7ddf0e8a 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -11,13 +11,11 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported import transformer_engine_torch as tex @@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: DimsType = Union[Iterable[int], int] # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) # delayed scaling diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 8adabd751..efef64a1e 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -11,14 +11,11 @@ from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling -from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.pytorch.utils import is_bf16_compatible -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available from transformer_engine.pytorch.utils import gpu_autocast_ctx # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) class TestFusedOptimizer: @@ -188,7 +185,7 @@ def gen_precision_aware_test( build_model_context = nullcontext build_model_context_args = {} if use_fp8_params: - build_model_context = fp8_model_init + build_model_context = quantized_model_init build_model_context_args["enabled"] = True with build_model_context(**build_model_context_args): @@ -286,7 +283,7 @@ def test_fp32_no_master(self): exp_avg_sq_dtype=torch.float32, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp32_master(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -298,7 +295,7 @@ def test_fp32_master(self): exp_avg_sq_dtype=torch.float32, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp32_master_store_param_remainders(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -311,7 +308,7 @@ def test_fp32_master_store_param_remainders(self): store_param_remainders=True, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp16_master(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -325,7 +322,7 @@ def test_fp16_master(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_grad(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -339,7 +336,7 @@ def test_bf16_grad(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp16_exp_avg(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -353,7 +350,7 @@ def test_fp16_exp_avg(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_exp_avg(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -367,7 +364,7 @@ def test_bf16_exp_avg(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): self.gen_precision_aware_test( @@ -382,7 +379,7 @@ def test_fp8_exp_avg(self): master_atol=1e-2, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp16_exp_avg_sq(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -396,7 +393,7 @@ def test_fp16_exp_avg_sq(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_exp_avg_sq(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -410,7 +407,7 @@ def test_bf16_exp_avg_sq(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg_sq(self): self.gen_precision_aware_test( @@ -424,7 +421,7 @@ def test_fp8_exp_avg_sq(self): skip_assert=True, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_model_weight_cast(self): dtype = torch.bfloat16 model = MultiheadAttention( @@ -468,7 +465,7 @@ def test_bf16_model_weight_cast(self): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_model_weight_cast(self): dtype = torch.bfloat16 - with fp8_model_init(enabled=True, recipe=DelayedScaling()): + with quantized_model_init(enabled=True, recipe=DelayedScaling()): model = MultiheadAttention( hidden_size=1024, num_attention_heads=16, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 231fa64bc..d2770347a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,8 +7,6 @@ from collections.abc import Iterable import io import math -import pathlib -import sys from typing import Optional import pytest @@ -17,7 +15,6 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, @@ -28,28 +25,27 @@ ForwardLinearBiasAdd, ForwardLinearScaleAdd, ) -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Tensor, +from transformer_engine.pytorch import ( + QuantizedTensor, Float8CurrentScalingQuantizer, Float8Quantizer, + MXFP8Quantizer, + NVFP4Quantizer, + is_bf16_available, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer -from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex # Import utility functions from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states # Check for supported quantization schemes -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher _dtypes.append(torch.bfloat16) # Supported devices @@ -372,7 +368,7 @@ def test_fp8_scale_update( ) # Construct model - with te.fp8_model_init(recipe=recipe): + with te.quantized_model_init(recipe=recipe): model = te_ops.basic.BasicLinear( size, size, @@ -404,7 +400,7 @@ def test_fp8_scale_update( ) # Training step - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): y = model(x) y.backward(dy) with torch.no_grad(): @@ -473,7 +469,7 @@ def test_dtype_cast( ) # Construct operation - with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): + with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): op = te_ops.Linear(size, size, bias=False, device=device, dtype=init_dtype) with torch.no_grad(): op.weight.copy_(w_test) @@ -530,7 +526,7 @@ def test_pyt_autocast( # Construct operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weights, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weights, recipe=recipe): op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) # Check forward and backward pass @@ -540,7 +536,7 @@ def test_pyt_autocast( device=device, requires_grad=True, ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): with torch.autocast(device_type=device.type, dtype=autocast_dtype): y = op(x) y.backward(torch.zeros_like(y)) @@ -553,7 +549,7 @@ def test_pyt_autocast( x.grad = None op.weight.grad = None with torch.autocast(device_type=device.type, dtype=autocast_dtype): - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = op(x) y.backward(torch.zeros_like(y)) assert y.dtype == autocast_dtype @@ -803,7 +799,7 @@ def test_quantize( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) recipe = make_recipe(quantization) - with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): + with te.autocast(enabled=with_quantization, recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -897,7 +893,7 @@ def _test_basic_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -914,7 +910,7 @@ def _test_basic_linear( op, te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -1075,7 +1071,7 @@ def test_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.Linear( in_features, out_features, @@ -1091,7 +1087,7 @@ def test_linear( del b_test for param in op.parameters(): param.requires_grad_(requires_grad=weight_requires_grad) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = op(x_test) if input_requires_grad or weight_requires_grad: y_test.backward(dy_test) @@ -1192,7 +1188,7 @@ def test_layer_norm( op, te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -1354,7 +1350,7 @@ def test_rmsnorm( op, te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -1654,7 +1650,7 @@ def test_activation( make_op(cache_quantized_input=cache_quantized_input), te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -1721,7 +1717,7 @@ def test_swiglu( te_ops.SwiGLU(), te_ops.Quantize(forward=quantize_forward, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -1792,7 +1788,7 @@ def test_clamped_swiglu( te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), te_ops.Quantize(forward=quantize_forward, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -2002,7 +1998,7 @@ def test_forward_linear_bias_activation( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): + with te.quantized_model_init(enabled=quantized_compute, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -2018,7 +2014,7 @@ def test_forward_linear_bias_activation( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -2112,7 +2108,7 @@ def test_forward_linear_bias_add( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -2129,7 +2125,7 @@ def test_forward_linear_bias_add( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -2218,7 +2214,7 @@ def test_forward_linear_scale_add( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -2234,7 +2230,7 @@ def test_forward_linear_scale_add( with torch.no_grad(): model[0].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -2325,7 +2321,7 @@ def test_backward_activation_bias( with torch.no_grad(): model[1].bias.copy_(b_test) del b_test - with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): + with te.autocast(enabled=with_quantization, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -2503,7 +2499,7 @@ def test_backward_linear_add( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight): + with te.quantized_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.MakeExtraOutput(in_place=True), te_ops.Linear( @@ -2517,7 +2513,7 @@ def test_backward_linear_add( with torch.no_grad(): model[1].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y1_test, y2_test = model(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -2598,7 +2594,7 @@ def test_backward_linear_scale( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight): + with te.quantized_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -2612,7 +2608,7 @@ def test_backward_linear_scale( with torch.no_grad(): model[0].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) (y_test * dy_test).sum().backward() @@ -2672,7 +2668,7 @@ def test_linear( # Construct model recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model_save = te_ops.Sequential( te_ops.Linear(in_features, out_features, device=device, dtype=dtype) ) @@ -2683,7 +2679,7 @@ def test_linear( x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) dy = torch.randn(out_shape, dtype=dtype, device=device) optim_save.zero_grad() - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = model_save(x) y.backward(dy) optim_save.step() @@ -2712,14 +2708,14 @@ def test_linear( ys_save = [] for i in range(post_checkpoint_steps): optim_save.zero_grad() - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = model_save(xs_save[i]) y.backward(dys[i]) optim_save.step() ys_save.append(y) # Load checkpoint - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model_load = te_ops.Sequential( te_ops.Linear(in_features, out_features, device=device, dtype=dtype) ) @@ -2732,7 +2728,7 @@ def test_linear( ys_load = [] for i in range(post_checkpoint_steps): optim_load.zero_grad() - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = model_load(xs_load[i]) y.backward(dys[i]) optim_load.step() @@ -2819,7 +2815,7 @@ def test_layernorm_mlp( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): if normalization == "LayerNorm": norm = te_ops.LayerNorm( hidden_size, @@ -2850,6 +2846,6 @@ def test_layernorm_mlp( dtype=dtype, ) forward = te_ops.Sequential(norm, ffn1, act, ffn2) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) diff --git a/tests/pytorch/test_hf_integration.py b/tests/pytorch/test_hf_integration.py index e74b16022..b014201c2 100644 --- a/tests/pytorch/test_hf_integration.py +++ b/tests/pytorch/test_hf_integration.py @@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch import TransformerLayer class SimpleTEModel(PreTrainedModel): diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index 6cbb3b9d2..46ba82187 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -5,7 +5,7 @@ import pytest import torch -import transformer_engine.pytorch as te +import transformer_engine.pytorch import transformer_engine_torch as tex from transformer_engine.pytorch.optimizers import MultiTensorApply diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a0e285b91..bef076a38 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -12,18 +12,15 @@ import torch.nn as nn from torch.nn import Parameter -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - fp8_autocast, - fp8_model_init, -) +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, attention_mask_func, - is_bf16_compatible, ) from transformer_engine.pytorch import ( + autocast, + quantized_model_init, DotProductAttention, LayerNormLinear, LayerNormMLP, @@ -35,26 +32,28 @@ LayerNorm, Fp8Padding, Fp8Unpadding, -) -from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, + get_device_compute_capability, + is_fp8_available, + is_mxfp8_available, + is_fp8_block_scaling_available, + is_bf16_available, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch import checkpoint as te_checkpoint +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace -from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states, get_available_attention_backends # Only run FP8 tests on supported devices. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) +fp8_block_scaling_available = is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -77,7 +76,7 @@ input_formats_inference = ["sbhd", "bshd"] param_types = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) batch_sizes = [1, 2] @@ -548,7 +547,7 @@ def _test_e2e_selective_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -575,7 +574,7 @@ def _test_e2e_selective_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -637,7 +636,7 @@ def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -665,7 +664,7 @@ def _test_e2e_full_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): if recompute: te_out = te_checkpoint( block, @@ -1088,7 +1087,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, ) inp_hidden_states.retain_grad() - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): out = block(inp_hidden_states) if isinstance(out, (List, Tuple)): out = out[0] @@ -1304,7 +1303,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): te_linear_ref = Linear( config.hidden_size, 4 * config.hidden_size, @@ -1758,7 +1757,7 @@ def _test_grouped_linear_accuracy( else: m_splits = torch.tensor([config.max_seqlen_q]) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): if isinstance(block, GroupedLinear): m_splits = m_splits * bs out = block(inp_hidden_states, m_splits.tolist()) @@ -1820,7 +1819,7 @@ def test_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -1956,7 +1955,7 @@ def test_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2110,7 +2109,7 @@ def _generate_random_numbers(n, total_sum): m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): out = block(inp_hidden_states, m_splits) else: @@ -2158,7 +2157,7 @@ def test_padding_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -2169,7 +2168,7 @@ def test_padding_grouped_linear_accuracy( fp8=fp8, ).eval() - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2229,7 +2228,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -2240,7 +2239,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( fp8=fp8, ).eval() - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2390,7 +2389,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -2417,7 +2416,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) - with fp8_autocast(enabled=True, fp8_recipe=recipe): + with autocast(enabled=True, recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index e5368497d..f8b4d7481 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -34,7 +34,7 @@ from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.utils import get_default_init_method import tensorrt as trt @@ -57,8 +57,8 @@ # The directory where this file is stored. TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) fp8_recipes = [] if mxfp8_available: @@ -178,8 +178,8 @@ def do_export( input_names = input_names or ["input"] output_names = output_names or ["output"] - with torch.inference_mode(), te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + with torch.inference_mode(), te.autocast( + enabled=fp8_recipe is not None, recipe=fp8_recipe ), warnings.catch_warnings(): warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") @@ -233,8 +233,8 @@ def te_infer( fp8_recipe: recipe.Recipe, ): """Transformer Engine forward propagation.""" - with torch.inference_mode(), te.fp8_autocast( - enabled=is_fp8, fp8_recipe=fp8_recipe + with torch.inference_mode(), te.autocast( + enabled=is_fp8, recipe=fp8_recipe ), warnings.catch_warnings(): te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) if not isinstance(te_outputs, tuple): @@ -440,7 +440,7 @@ def forward(self, inp): bias_str = "_bias" if use_bias else "" high_prec_str = dtype2str(precision) fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( device="cuda" ) @@ -500,7 +500,7 @@ def _test_export_layernorm( fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx" with torch.no_grad(): - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm model = layernorm_cls( hidden_size, @@ -568,7 +568,7 @@ def _test_export_layernorm_linear( fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" with torch.no_grad(): - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): model = te.LayerNormLinear( hidden_size, 3 * hidden_size, @@ -654,7 +654,7 @@ def _test_export_layernorm_mlp( bias_str = "_bias" if use_bias else "" high_prec_str = dtype2str(precision) fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): model = te.LayerNormMLP( hidden_size, ffn_hidden_size, @@ -1160,13 +1160,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe): inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): out_ref = model(*inps) onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx") os.close(onnx_fd) try: - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): with te.onnx_export(enabled=True): torch.onnx.export( model, diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index fa56852ff..e325146b7 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -4,7 +4,7 @@ import random import torch -from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch import parallel_cross_entropy from utils import dtype_tols diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index e1cd48e18..e8a7bedc8 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -8,6 +8,7 @@ import pytest from typing import Dict, List +import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch import ( moe_permute as te_permute, @@ -16,14 +17,12 @@ moe_sort_chunks_by_index as te_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, ) -from transformer_engine.pytorch.utils import is_bf16_compatible -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( Float8Quantizer, Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, ) -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex import copy @@ -1119,7 +1118,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): # TE tensor dtypes _te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] -if is_bf16_compatible(): +if te.is_bf16_available(): _te_dtypes.append(tex.DType.kBFloat16) @@ -1239,10 +1238,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): # Only run FP8 tests on H100. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True ) fp8_recipes = [ recipe.MXFP8BlockScaling(), diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 004abfd97..71032d23f 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -from typing import Iterable, Optional +from typing import Optional import pytest import torch @@ -10,28 +10,34 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch import ( + Float8BlockQuantizer, + MXFP8Quantizer, + Float8Quantizer, + NVFP4Quantizer, + quantized_model_init, + Linear, + LayerNormLinear, + LayerNormMLP, + GroupedLinear, +) + import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import ( +from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, _amax_and_scale_update, - fp8_model_init, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer -from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear -from transformer_engine.pytorch.distributed import fp8_autocast from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True ) +fp4_available, reason_for_no_fp4 = te.is_nvfp4_available(return_reason=True) # FP8 per tensor delayed scaling @@ -64,7 +70,7 @@ def test_fp8_scale_update_with_linear_module( amax_history_len=amax_history_len, amax_compute_algo=amax_compute_algo, ) - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): module = te.Linear(16, 16) y = module( torch.randn([16, 16], device="cuda"), @@ -120,7 +126,7 @@ def test_fp8_scale_update_with_linear_module( # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Perform forward, backward, and optimizer steps to update fp8_meta - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.autocast(enabled=True, recipe=recipe): x = torch.randn([16, 16], device="cuda") y = module(x, is_first_microbatch=is_first_microbatch) y.backward(torch.randn_like(y)) @@ -219,7 +225,7 @@ def test_fp8_scale_update_with_linear_fuser_op( op.weight.fill_(w_history[-1]) # Forward and backward pass - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): y = op(x) y.backward(dy) @@ -301,7 +307,7 @@ def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype scaling_factor_compute_algo = None if fused_update: scaling_factor_compute_algo = ( - lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute( + lambda amax, scale, fp8_max, recipe: te.quantization._default_sf_compute( amax, scale, fp8_max, recipe.margin ) ) @@ -311,7 +317,7 @@ def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype # Setup fp8_meta dictionary def setup_fp8_meta(): - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): module = te.Linear(16, 16) y = module(torch.zeros([16, 16], device="cuda")) y.backward(torch.zeros_like(y)) @@ -393,11 +399,11 @@ def setup_fp8_meta(): ], ) def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): - with fp8_model_init(enabled=True, recipe=model_init_recipe): + with quantized_model_init(enabled=True, recipe=model_init_recipe): linear = Linear(32, 32).cuda() x = torch.randn(32, 32, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): + with te.autocast(enabled=True, recipe=DelayedScaling()): with pytest.raises(RuntimeError) as excinfo: _ = linear(x) assert "Recipe mismatch for " in str(excinfo.value) @@ -436,7 +442,7 @@ def test_dynamic_recipe_update( # Run initial iterations with DelayedScaling for _ in range(3): x = torch.randn(batch_size, in_features, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=initial_recipe): + with te.autocast(enabled=True, recipe=initial_recipe): y = linear(x) loss = y.mean() loss.backward() @@ -453,7 +459,7 @@ def test_dynamic_recipe_update( if i == 0: # Expect a warning on the first iteration with the new recipe with pytest.warns(UserWarning, match="Recipe type changed"): - with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + with te.autocast(enabled=True, recipe=target_recipe): y = linear(x) for quantizer in linear.quantizers["scaling_fwd"]: assert isinstance(quantizer, expected_quantizer_type) @@ -461,7 +467,7 @@ def test_dynamic_recipe_update( # No warning expected on subsequent iterations with warnings.catch_warnings(): warnings.simplefilter("error") # Raise error if unexpected warning occurs - with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + with te.autocast(enabled=True, recipe=target_recipe): y = linear(x) loss = y.mean() loss.backward() @@ -485,7 +491,7 @@ def test_quantizer_update(self, module_class): batch_size = 32 recipe = DelayedScaling(amax_history_len=1024) - with fp8_model_init(recipe=recipe): + with quantized_model_init(recipe=recipe): if module_class == GroupedLinear: module = module_class(1, in_features, out_features).cuda() else: @@ -493,7 +499,7 @@ def test_quantizer_update(self, module_class): x = torch.randn(batch_size, in_features, device="cuda") recipe = DelayedScaling(amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.autocast(enabled=True, recipe=recipe): warn_msg = "Quantizer is being updated, this may affect model behavior" with pytest.warns(UserWarning, match=warn_msg): if module_class == GroupedLinear: @@ -502,9 +508,6 @@ def test_quantizer_update(self, module_class): y = module(x) -fp4_available, reason_for_no_fp4 = FP8GlobalStateManager.is_nvfp4_available() - - @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize( diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 981c58243..e283842ec 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -8,18 +8,16 @@ import pytest import os -import transformer_engine.pytorch -from transformer_engine.pytorch.fp8 import ( - fp8_autocast, - FP8GlobalStateManager, - fp8_model_init, -) +import transformer_engine +import transformer_engine.pytorch as te +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, - is_bf16_compatible, ) from transformer_engine.pytorch import ( + autocast, + quantized_model_init, LayerNormLinear, Linear, GroupedLinear, @@ -27,26 +25,25 @@ TransformerLayer, RMSNorm, LayerNorm, + Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, + MXFP8Tensor, + checkpoint, + QuantizedTensor, + is_bf16_available, ) from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8CurrentScalingQuantizer, - Float8Quantizer, - Float8Tensor, -) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data -from transformer_engine.pytorch.distributed import checkpoint from utils import ModelConfig # Only run FP8 tests on supported devices. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) # Record initial RNG state from script run. seed = 1234 @@ -108,7 +105,7 @@ def nvfp4_vanilla(): fp8_recipes.append(None) param_types = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) all_boolean = [True, False] @@ -160,7 +157,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): use_fp8 = fp8_recipe is not None with torch.autocast(device_type="cuda", enabled=True, dtype=dtype): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() @@ -199,7 +196,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci p.main_grad = torch.zeros_like(p) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -227,7 +224,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states) loss = te_out.sum() loss.backward() @@ -253,7 +250,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -285,7 +282,7 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -314,7 +311,7 @@ def _test_sanity_common( _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): if not microbatching: te_out = block(te_inp) else: @@ -455,7 +452,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): + with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_linear = Linear( config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -463,7 +460,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): out = te_linear(inp_hidden_states) loss = out.sum() loss.backward() @@ -496,7 +493,7 @@ def test_sanity_grouped_linear( pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): + with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -512,7 +509,7 @@ def test_sanity_grouped_linear( elif empty_split == "middle": m_splits[num_gemms // 2] = 0 - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): out = te_grouped_linear(inp_hidden_states, m_splits) loss = out.sum() loss.backward() @@ -976,9 +973,9 @@ def test_replace_raw_data_for_float8tensor(): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -def test_fp8_model_init_high_precision_init_val(): - """Test fp8_model_init with preserve_high_precision_init_val=True""" - with fp8_model_init(preserve_high_precision_init_val=True): +def test_quantized_model_init_high_precision_init_val(): + """Test quantized_model_init with preserve_high_precision_init_val=True""" + with quantized_model_init(preserve_high_precision_init_val=True): model = Linear(768, 768) weight = model.weight @@ -1051,7 +1048,7 @@ def test_linear_frozen_weights_memory_default_recipe(): linear.weight.requires_grad = False # Forward and backward pass with FP8 - with fp8_autocast(): + with autocast(): o = linear(x) g_o = torch.randn_like(o) @@ -1105,7 +1102,7 @@ def test_inference_mode( # Construct module module = None with torch.no_grad(): - with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe): + with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe): if module_name == "Linear": module = Linear(hidden_size, hidden_size) elif module_name == "LayerNormLinear": @@ -1140,6 +1137,6 @@ def check_weights(): kwargs = {} if module_name == "GroupedLinear": kwargs["m_splits"] = [sequence_length] - with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe): + with autocast(enabled=with_quantization, recipe=quantization_recipe): y = module(x, **kwargs) check_weights() diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index d77256b7f..72a1b3b53 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -7,14 +7,14 @@ import logging import os from contextlib import contextmanager +from typing import Optional, Tuple, Dict, Any, List -import pytest import torch import transformer_engine -import transformer_engine.common.recipe -import transformer_engine.pytorch as te import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import InferenceParams from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( get_attention_backend, diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index f70b43a7a..7bc39f074 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -161,7 +161,7 @@ def scaling_factor_compute(amax: Tensor, where `Tensor` is a framework tensor type. reduce_amax: bool, default = `True` By default, if `torch.distributed` is initialized, the `amax` value for FP8 - tensors is reduced across the `fp8_group` (specified in the `fp8_autocast` + tensors is reduced across the `amax_reduction_group` (specified in the `autocast` call). This keeps the amaxes and scaling factors synced across the given distributed group. If set to `False`, this reduction is skipped and every GPU maintains local amaxes and scaling factors. To ensure results are @@ -169,7 +169,7 @@ def scaling_factor_compute(amax: Tensor, ranks must checkpoint in order to store the local tensors. fp8_dpa: bool, default = `False` Whether to enable FP8 dot product attention (DPA). When the model is placed in an - `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + `autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the inputs from higher precision to FP8, performs attention in FP8, and casts tensors back to higher precision as outputs. FP8 DPA currently is only supported in the `FusedAttention` backend. diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index 4b01f9712..58c7379b5 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.quantization import _default_sf_compute def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 354a1293e..6259a7ad8 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,7 +34,7 @@ from . import flax from . import quantize -from .quantize import fp8_autocast, update_collections +from .quantize import autocast, fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME from .sharding import MeshResource @@ -45,6 +45,7 @@ __all__ = [ "NVTE_FP8_COLLECTION_NAME", + "autocast", "fp8_autocast", "update_collections", "MeshResource", diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index c95765bf3..1eafed413 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -66,7 +66,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: for 1D-sharding tensor parallelism. .. warning:: - Please make sure ShardingResource is set via fp8_autocast before calling this function. + Please make sure ShardingResource is set via autocast before calling this function. .. note:: This function is only needed when using TransformerLayer. For other modules, such as diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 70611cbea..06c67b62e 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -7,6 +7,7 @@ This module provides configuration and helper functions for managing quantization metadata in JAX, including support for different scaling modes and datatypes. """ + from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass @@ -23,7 +24,14 @@ from flax.core.frozen_dict import FrozenDict from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version -from transformer_engine.common import recipe +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, + NVFP4BlockScaling, +) from transformer_engine.jax.sharding import ( global_shard_guard, MeshResource, @@ -39,6 +47,7 @@ __all__ = [ "get_quantize_config", "get_quantize_config_with_recipe", + "autocast", "fp8_autocast", "is_fp8_available", "is_scaling_mode_supported", @@ -51,8 +60,6 @@ "TensorSource", ] -_is_fp8_available = None -_reason_for_no_fp8 = "" _is_scaling_mode_supported = None _reason_for_no_scaling_mode = "" Collection = Union[Dict, FrozenDict] @@ -195,22 +202,22 @@ def get_supported_scaling_modes() -> List[ScalingMode]: ] -def get_supported_quantization_recipes() -> List[recipe.Recipe]: +def get_supported_quantization_recipes() -> List[Recipe]: """Get all supported quantization recipes.""" # We don't support all the recipes TE/Common supports yet # return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()] all_recipes = [ - recipe.DelayedScaling(), - recipe.Float8CurrentScaling(), - recipe.MXFP8BlockScaling(), - recipe.NVFP4BlockScaling(), + DelayedScaling(), + Float8CurrentScaling(), + MXFP8BlockScaling(), + NVFP4BlockScaling(), ] return [ recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0] ] -def _format2dtypes(format_: recipe.Format): +def _format2dtypes(format_: Format): """Convert recipe.Format.dtype to corresponding JAX dtypes. Args: @@ -219,13 +226,13 @@ def _format2dtypes(format_: recipe.Format): Returns: A tuple of (forward_dtype, backward_dtype) for the given format """ - if format_ == recipe.Format.E4M3: + if format_ == Format.E4M3: return jnp.float8_e4m3fn, jnp.float8_e4m3fn - if format_ == recipe.Format.E5M2: + if format_ == Format.E5M2: return jnp.float8_e5m2, jnp.float8_e5m2 - if format_ == recipe.Format.HYBRID: + if format_ == Format.HYBRID: return jnp.float8_e4m3fn, jnp.float8_e5m2 - if format_ == recipe.Format.E2M1: + if format_ == Format.E2M1: return jnp.float4_e2m1fn, jnp.float4_e2m1fn return jnp.bfloat16, jnp.bfloat16 @@ -289,7 +296,7 @@ class BaseQuantizeConfig(ABC): AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize the quantization configuration from a given recipe. Args: @@ -359,7 +366,7 @@ def is_supported(self) -> tuple[bool, str]: class NoOpQuantizeConfig(BaseQuantizeConfig): """Configuration class higher-precision non-quantized operation.""" - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize no-op configuration.""" raise NotImplementedError( "NoOpQuantizeConfig cannot be initialize from a recipe as it represents" @@ -399,7 +406,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): FP8 quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize delayed scaling FP8 configuration. Args: @@ -477,7 +484,7 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig): FP8 quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize current scaling FP8 configuration. Args: @@ -519,7 +526,7 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig): FP8 quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize block scaling FP8 configuration. Args: @@ -560,7 +567,7 @@ class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): This class provides specific initialization and finalization for NVFP4 scaling quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize block scaling FP8 configuration. Args: @@ -622,12 +629,12 @@ def get_quantize_flax_meta( def get_quantize_config(): - """Global instance of BaseQuantizeConfig set by fp8_autocast context.""" + """Global instance of BaseQuantizeConfig set by autocast context.""" return _QUANTIZE_CONFIG def get_quantize_config_class( - fp8_recipe: recipe.Recipe, + fp8_recipe: Recipe, ) -> Type[BaseQuantizeConfig]: """Get the quantization configuration class based on the FP8 recipe. @@ -636,18 +643,18 @@ def get_quantize_config_class( Returns: The quantization config class corresponding to the given recipe. """ - if isinstance(fp8_recipe, recipe.DelayedScaling): + if isinstance(fp8_recipe, DelayedScaling): return DelayedScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + if isinstance(fp8_recipe, MXFP8BlockScaling): return BlockScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): + if isinstance(fp8_recipe, Float8CurrentScaling): return CurrentScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.NVFP4BlockScaling): + if isinstance(fp8_recipe, NVFP4BlockScaling): return NVFP4ScalingQuantizeConfig raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") -def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe): +def get_quantize_config_with_recipe(fp8_recipe: Recipe): """Get the quantization configuration object based on the FP8 recipe.""" config = get_quantize_config_class(fp8_recipe)() config.initialize_from_recipe(fp8_recipe) @@ -655,14 +662,14 @@ def get_quantize_config_with_recipe(fp8_recipe: recipe.Recipe): @contextmanager -def fp8_autocast( +def autocast( enabled: bool = False, - fp8_recipe: Optional[recipe.Recipe] = None, + recipe: Optional[Recipe] = None, mesh_resource: Optional[MeshResource] = None, ) -> None: - r"""Context manager for FP8 automatic mixed precision. + r"""Context manager for FP8 or FP4 usage. - This context manager enables FP8 quantization for the duration of its context. + This context manager enables quantization for the duration of its context. .. code-block:: python mesh_shape = (4, 2) @@ -673,7 +680,7 @@ def fp8_autocast( with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) - with fp8_autocast(enabled=True, mesh_resource=mesh_resource): + with autocast(enabled=True, mesh_resource=mesh_resource): rules = extend_logical_axis_rules(tuple()) transformer = TransformerLayer() @@ -690,15 +697,15 @@ def fp8_autocast( ---------- enabled: bool, default = False Whether or not to enable fp8 - fp8_recipe: recipe.DelayedScaling, default = None - Recipe used for FP8 training. + recipe: recipe.DelayedScaling, default = None + recipe used for low precision quantization. mesh_resource: MeshResource, default = None Specify the mesh axes for data and tensor parallelism to shard along. If set to None, then no data or tensor parallelism will be used. """ - if fp8_recipe is None: - fp8_recipe = recipe.DelayedScaling() + if recipe is None: + recipe = DelayedScaling() global _QUANTIZE_CONFIG @@ -709,15 +716,45 @@ def fp8_autocast( try: with global_shard_guard(mesh_resource): if enabled: - _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)() + _QUANTIZE_CONFIG = get_quantize_config_class(recipe)() is_supported, reason = _QUANTIZE_CONFIG.is_supported() assert is_supported, reason - _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe) + _QUANTIZE_CONFIG.initialize_from_recipe(recipe) yield finally: _QUANTIZE_CONFIG = old_quantize_config +@contextmanager +def fp8_autocast( + enabled: bool = False, + fp8_recipe: Optional[Recipe] = None, + mesh_resource: Optional[MeshResource] = None, +) -> None: + """ + .. warning:: + + fp8_autocast is deprecated and will be removed in a future release. + Use autocast(enabled=..., recipe=..., mesh_resource=...) instead. + + """ + + warnings.warn( + "fp8_autocast is deprecated and will be removed in a future release. " + "Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + # Call new implementation. + with autocast( + enabled=enabled, + recipe=fp8_recipe, + mesh_resource=mesh_resource, + ): + yield + + def update_collections(new: Collection, original: Collection) -> Collection: r"""Update collections with new values while preserving original structure. diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 3256512b5..77c71b811 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -46,8 +46,18 @@ def torch_version() -> tuple[int, ...]: moe_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs, ) -from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.fp8 import fp8_model_init +from transformer_engine.pytorch.quantization import fp8_autocast +from transformer_engine.pytorch.quantization import fp8_model_init +from transformer_engine.pytorch.quantization import autocast +from transformer_engine.pytorch.quantization import quantized_model_init +from transformer_engine.pytorch.quantization import is_fp8_available +from transformer_engine.pytorch.quantization import is_mxfp8_available +from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available +from transformer_engine.pytorch.quantization import is_nvfp4_available +from transformer_engine.pytorch.quantization import get_default_recipe +from transformer_engine.pytorch.utils import get_cudnn_version +from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.utils import is_bf16_available from transformer_engine.pytorch.graph import make_graphed_callables from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker @@ -61,14 +71,17 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor import NVFP4Quantizer from transformer_engine.pytorch.tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor import Float8TensorStorage from transformer_engine.pytorch.tensor import MXFP8TensorStorage from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage +from transformer_engine.pytorch.tensor import NVFP4TensorStorage from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor import Float8Tensor from transformer_engine.pytorch.tensor import MXFP8Tensor from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor import NVFP4Tensor from transformer_engine.pytorch.tensor import prepare_for_saving from transformer_engine.pytorch.tensor import restore_from_saved diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d75481ad9..6dfe0d31b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -42,7 +42,7 @@ META_O, META_QKV, ) -from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype, FP8GlobalStateManager +from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( @@ -1074,7 +1074,7 @@ def forward( nvtx_label = "transformer_engine.FusedAttnFunc.forward" nvtx_range_push(f"{nvtx_label}") - # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index d1374e949..a474cb809 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -19,7 +19,7 @@ fused_attn_bwd, FusedAttnBackend, ) -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser @@ -1164,7 +1164,7 @@ def forward( is_input_fp8 = isinstance(q, Float8Tensor) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: @@ -3151,7 +3151,7 @@ def forward( is_input_fp8 = isinstance(q, Float8Tensor) is_output_fp8 = fp8_output is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) - # recipe passed in through fp8_autocast or set by NVTE_DPA_FP8_RECIPE; + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 0a8802fb0..6d9ce9a52 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -21,7 +21,7 @@ Float8CurrentScaling, ) from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.fp8 import ( +from transformer_engine.pytorch.quantization import ( get_fp8_te_dtype, FP8GlobalStateManager, RecipeState, @@ -91,26 +91,26 @@ This feature is **experimental** and subject to change. Some models may use different FP8 recipes for their linear layers and attention layers. To support this, -users can either use multiple, nested fp8_autocast() contexts to assign a distinct recipe for each layer, -or use a single fp8_autocast() for the non-attention layers and configure the recipe for the attention +users can either use multiple, nested autocast() contexts to assign a distinct recipe for each layer, +or use a single autocast() for the non-attention layers and configure the recipe for the attention layers as follows. +-------------------+-----------+-----------------------------------------------------------------------------------+ | Linear | Attention | Configuration | +===================+===========+===================================================================================+ -| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to fp8_autocast(); | +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | | | | export NVTE_DPA_FP8_RECIPE="F16" | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8DS | Pass FP8DS to fp8_autocast(); | +| FP8DS | FP8DS | Pass FP8DS to autocast(); | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8DS | Pass FP8CS to fp8_autocast(); | +| FP8CS | FP8DS | Pass FP8CS to autocast(); | | | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8DS | Pass NVFP4 to fp8_autocast(); | +| NVFP4 | FP8DS | Pass NVFP4 to autocast(); | | | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | | | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | | | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | @@ -118,19 +118,19 @@ | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8DS | FP8CS | Pass FP8DS to fp8_autocast(); | +| FP8DS | FP8CS | Pass FP8DS to autocast(); | | | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| | | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| FP8CS | FP8CS | Pass FP8CS to fp8_autocast(); | +| FP8CS | FP8CS | Pass FP8CS to autocast(); | | | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | | | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | | | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | | | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | | | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | +-------------------+-----------+-----------------------------------------------------------------------------------+ -| NVFP4 | FP8CS | Pass NVFP4 to fp8_autocast(); | +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | | | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | | | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | | | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | @@ -544,7 +544,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: """ _original_recipe = self.fp8_meta.get("recipe", None) - # global recipe set in fp8_autocast() + # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_recipe.custom(): return @@ -560,7 +560,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: fp8_recipe_dpa = fp8_recipe fp8_recipes = fp8_recipe if _dpa_fp8_recipe == "F16": - # ignore the recipe from fp8_autocast, set fp8_dpa = False, fp8_mha = False + # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False fp8_recipe.fp8_dpa = False fp8_recipe.fp8_mha = False elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 8b26a1760..c8cc3d29f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -40,7 +40,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType @@ -222,7 +222,7 @@ class AttentionParams: is_training: bool, default = `True` Whether in training mode (`True`) or inference mode (`False`) fp8: bool, default = `False` - Whether `DotProductAttention` is in an `fp8_autocast` region. + Whether `DotProductAttention` is in an `autocast` region. fp8_meta: Optional[Dict[str Any]], default = `None` The FP8 metadata tensor of `DotProductAttention`. inference_params: Optional[InferenceParams], default = `None` diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 8f0183224..b3bda677b 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -9,7 +9,7 @@ import torch from transformer_engine.debug.pytorch.debug_state import TEDebugState -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm @@ -33,7 +33,7 @@ from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb -# Force DotProductAttention to use a different recipe than the fp8_recipe set in fp8_autocast(). +# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast(). # Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling" # and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa. _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 51fbb50c4..5ed73f678 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -36,7 +36,7 @@ needs_quantized_gemm, ) from .constants import dist_group_type -from .fp8 import FP8GlobalStateManager, fp8_autocast +from .quantization import FP8GlobalStateManager, autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.nvfp4_tensor import NVFP4Quantizer @@ -419,8 +419,8 @@ def backward( detached_inputs = detach_variable(inputs) with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( activation_recompute=True, recompute_phase=True - ), fp8_autocast( - enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe + ), autocast( + enabled=ctx.fp8, recipe=ctx.fp8_recipe ): outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) @@ -754,8 +754,8 @@ def checkpoint( def recompute_fn(*args, **kwargs): with torch.autograd.enable_grad(), ( te_recompute_ctx - ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast( - enabled=fp8, fp8_recipe=fp8_recipe + ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, autocast( + enabled=fp8, recipe=fp8_recipe ): function(*args, **kwargs) @@ -1969,7 +1969,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: if hasattr(fsdp_root, "primary_weights_in_fp8"): assert not fsdp_root.primary_weights_in_fp8, ( "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.fp8_model_init(...) context." + "Please initialize your model without the te.quantized_model_init(...) context." ) root_state = _get_module_fsdp_state(fsdp_root) assert root_state is not None, "Root module does not have a valid _FSDPState." @@ -1982,7 +1982,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: if hasattr(fsdp_module.module, "primary_weights_in_fp8"): assert not fsdp_module.module.primary_weights_in_fp8, ( "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.fp8_model_init(...) context." + "Please initialize your model without the te.quantized_model_init(...) context." ) setattr(fsdp_module.module, "fsdp_group", state.process_group) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index bfe241f81..f937b3de9 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -2,18 +2,26 @@ # # See LICENSE for license information. -"""FP8 utilities for TransformerEngine""" -from __future__ import annotations +""" +DEPRECATED in favor of `transformer_engine.pytorch.quantization.py`. +""" -import abc -import itertools -import os -from contextlib import contextmanager -from collections import deque -from typing import Callable, List, Optional, Dict, Any, Tuple, Union +# pylint: disable=wrong-import-position,unused-import -import torch -import transformer_engine_torch as tex +import warnings + +warnings.warn( + "Using deprecated internal API from Transformer Engine. " + "transformer_engine.pytorch.fp8 will be removed in a " + "future release.", + DeprecationWarning, + stacklevel=2, +) + + +# There are some users indirectly importing these classes +# from fp8.py. This ensure backwards compatibility. +# https://github.com/Lightning-AI/lightning-thunder/pull/2635. from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -25,1224 +33,36 @@ CustomRecipe, ) -from .constants import dist_group_type -from .utils import get_device_compute_capability -from .jit import jit_fuser - - -__all__ = ["fp8_autocast", "fp8_model_init"] - - -def check_fp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - if get_device_compute_capability() >= (9, 0): # hopper and above - return True, "" - if get_device_compute_capability() < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if float(torch.version.cuda) < 12.1: - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - return True, "" - - -def check_mxfp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - if get_device_compute_capability() >= (12, 0): - return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." - if get_device_compute_capability() >= (10, 0): # blackwell and above - return True, "" - return False, "Device compute capability 10.0 or higher required for MXFP8 execution." - - -def check_nvfp4_support() -> Tuple[bool, str]: - """Return if nvfp4 support is available""" - if get_device_compute_capability() >= (10, 0): # blackwell and above - return True, "" - return False, "Device compute capability 10.0 or higher required for NVFP4 execution." - - -def check_fp8_block_scaling_support() -> Tuple[bool, str]: - """Return if fp8 block scaling support is available""" - if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: - return True, "" - return ( - False, - "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", - ) - - -def check_recipe_support(recipe: Recipe) -> None: - """Check if the given recipe is supported.""" - recipe_supported = True - unsupported_reason = "" - if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): - recipe_supported, unsupported_reason = check_fp8_support() - elif isinstance(recipe, Float8BlockScaling): - recipe_supported, unsupported_reason = check_fp8_block_scaling_support() - elif isinstance(recipe, MXFP8BlockScaling): - recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason - - -def get_default_fp8_recipe() -> Recipe: - """FP8 recipe with default args.""" - if check_mxfp8_support()[0]: - return MXFP8BlockScaling() - if get_device_compute_capability() >= (12, 0): - # This is a temporary restriction until MXFP8 is supported for all gemm layouts. - return Float8CurrentScaling() - return DelayedScaling() - - -def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return torch.float8_e4m3fn - return torch.float8_e5m2 - - -def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: - """Get fp4 data type according to recipe and tensor""" - if fp4_recipe.fp4_format == Format.E2M1: - return tex.DType.kFloat4E2M1 - raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") - - -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get max representible FP8 value.""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return Format.E4M3.value.max_fwd - return Format.E5M2.value.max_fwd - - -class FP8GlobalStateManager: - """Class to keep track of and manipulate the global - FP8 state at different stages of execution. - """ - - FP8_ENABLED = False - FP8_CALIBRATION = False - FP8_RECIPE = None - FP8_DISTRIBUTED_GROUP = None - FP8_PARAMETERS = False - HIGH_PRECISION_INIT_VAL = False - IS_FIRST_FP8_MODULE = False - FP8_GRAPH_CAPTURING = False - FP8_AUTOCAST_DEPTH = 0 - global_amax_buffer = {} - global_amax_history_buffer = {} - global_scale_buffer = {} - fp8_tensors_recompute_buffer = [] - fp8_available = None - reason_for_no_fp8 = "" - autocast_arguments = {} - autocast_to_fp8_params = {} - fp8_param_to_autocast = {} - skip_fp8_weight_update_tensor = None - mxfp8_available = None - reason_for_no_mxfp8 = "" - fp8_block_scaling_available = None - reason_for_no_fp8_block_scaling = None - nvfp4_available = None - reason_for_no_nvfp4 = "" - - @classmethod - def reset(cls) -> None: - """Reset the global state""" - cls.FP8_ENABLED = False - cls.FP8_CALIBRATION = False - cls.FP8_RECIPE = None - cls.FP8_DISTRIBUTED_GROUP = None - cls.FP8_PARAMETERS = False - cls.HIGH_PRECISION_INIT_VAL = False - cls.IS_FIRST_FP8_MODULE = False - cls.FP8_GRAPH_CAPTURING = False - cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_amax_buffer = {} - cls.global_amax_history_buffer = {} - cls.global_scale_buffer = {} - cls.fp8_tensors_recompute_buffer = [] - cls.fp8_available = None - cls.reason_for_no_fp8 = "" - cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} - cls.skip_fp8_weight_update_tensor = None - cls.mxfp8_available = None - cls.reason_for_no_mxfp8 = "" - cls.fp8_block_scaling_available = None - cls.reason_for_no_fp8_block_scaling = "" - - @classmethod - def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: - """`skip_fp8_weight_update_tensor` inplace setter.""" - if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") - cls.skip_fp8_weight_update_tensor.fill_(skip) - - @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> None: - """`skip_fp8_weight_update_tensor` getter.""" - return cls.skip_fp8_weight_update_tensor - - @classmethod - def is_fp8_available(cls) -> Tuple[bool, str]: - """Return if fp8 support is available""" - if cls.fp8_available is None: - cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() - return cls.fp8_available, cls.reason_for_no_fp8 - - @classmethod - def is_mxfp8_available(cls) -> Tuple[bool, str]: - """Return if MXFP8/current scaling support is available.""" - if cls.mxfp8_available is None: - cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() - return cls.mxfp8_available, cls.reason_for_no_mxfp8 - - @classmethod - def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: - """Return if Float8 block scaling support is available.""" - if cls.fp8_block_scaling_available is None: - cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( - check_fp8_block_scaling_support() - ) - return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling - - @classmethod - def is_nvfp4_available(cls) -> Tuple[bool, str]: - """Return if NVFP4 support is available.""" - if cls.nvfp4_available is None: - cls.nvfp4_available, cls.reason_for_no_nvfp4 = check_nvfp4_support() - return cls.nvfp4_available, cls.reason_for_no_nvfp4 - - @staticmethod - def get_meta_tensor_key(forward: bool = True) -> str: - """Returns scaling key in `fp8_meta`.""" - if forward: - return "scaling_fwd" - return "scaling_bwd" - - @staticmethod - def get_fwd_bwd_key(forward: bool = True) -> str: - """Convert bool `forward` to string.""" - return "forward" if forward else "backward" - - @classmethod - def get_buffer_info(cls) -> str: - """ - Returns a key for `fp8_meta` that stores the module's index - in the global buffers along with autocast information. - """ - return "buffer_index_and_autocast_key" - - @classmethod - def get_key_in_buffer( - cls, - forward: bool, - fp8_recipe: Recipe, - fp8_group: dist_group_type, - ) -> str: - """Returns a key into the global FP8 buffers.""" - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{autocast_key}" - - @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: - """Splits buffer key into relevant parts.""" - forward, autocast_key = key.split("_", 1) - forward = forward == "forward" - return forward, autocast_key - - @classmethod - def add_fp8_tensors_to_global_buffer( - cls, - fp8_meta: Dict[str, Any], - ) -> None: - """ - Delayed scaling only. - - The amax reduction process happens completely outside the FP8 modules. - To participate in the reduction, the only role played by a module is - to call this function in order to append it's FP8 tensor into a global - buffer. There are 5 global buffers maintained, one each for amax, amax - history, scale, scale-inverse, and non-weight-mask. Each buffer has - keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix - to indicate the type of FP8 tensor, since the forward and backward - reductions happen separately. - - Note: For CG capture, this method is called from the graphed - wrapper. For non CG case, it's called from within the module. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Every module must call this function exactly once since - # the amax tensors are static. Ensures that compatibility - # with non-graphed modules is maintained. - index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. - if index_in_buffer in fp8_meta: - return - - fp8_meta[index_in_buffer] = [] - for forward in (True, False): - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - if fp8_meta_tensor_key not in fp8_meta: - # Handles non-parameter FP8 modules, e.g. DPA. - continue - - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - - if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history - ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) - fp8_meta[index_in_buffer].append(key) - - @classmethod - def is_fp8_enabled(cls) -> bool: - """Is FP8 enabled""" - return cls.FP8_ENABLED - - @classmethod - def is_fp8_calibration(cls) -> bool: - """Is FP8 calibration""" - return cls.FP8_CALIBRATION - - @classmethod - def with_fp8_parameters(cls) -> bool: - """Should the parameters be stored as FP8""" - return cls.FP8_PARAMETERS - - @classmethod - def with_high_precision_init_val(cls) -> bool: - """Should the high precision initial values be stored with FP8 parameters""" - return cls.HIGH_PRECISION_INIT_VAL - - @classmethod - def fp8_graph_capturing(cls) -> bool: - """Is CUDA graph capture under way?""" - return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() - - @classmethod - def is_first_fp8_module(cls): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. - """ - tmp = cls.IS_FIRST_FP8_MODULE - cls.IS_FIRST_FP8_MODULE = False - return tmp - - @classmethod - def get_fp8_recipe(cls) -> Recipe: - """Return the fp8 recipe""" - if cls.FP8_RECIPE is not None: - return cls.FP8_RECIPE - return get_default_fp8_recipe() - - @classmethod - def get_fp8_group(cls) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return cls.FP8_DISTRIBUTED_GROUP - - @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: - """FP8 autocast state getter""" - return ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) - - @classmethod - def set_fp8_autocast_state( - cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] - ) -> None: - """FP8 autocast state setter""" - ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) = fp8_state - - @staticmethod - def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: - """Reduce tensor across given group.""" - if torch.distributed.is_initialized(): - torch.distributed.all_reduce( - tensor, - op=torch.distributed.ReduceOp.MAX, - group=group, - async_op=False, - ) - - @classmethod - def reduce_and_update_fp8_tensors( - cls, - forward: bool = True, - ) -> None: - """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" - # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in cls.global_amax_buffer.items(): - # Check for forward or backward reduction. - fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) - if fwd_update != forward: - continue - if len(amax_buffer) == 0: - continue - - # Retrieve autocast specific args and concat amaxes. - recipe, group = cls.autocast_arguments[autocast_key] - contiguous_amax = torch.cat(amax_buffer) - - # Reduction. - if ( - recipe.reduce_amax - and torch.distributed.is_initialized() - and torch.distributed.get_world_size(group=group) > 1 - ): - cls.reduce_tensor_across_group_op_max(contiguous_amax, group) - - # Amax and scale update. - unfused_update = ( - bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) - or callable(recipe.amax_compute_algo) - or callable(recipe.scaling_factor_compute_algo) - ) - - if not unfused_update: - tex.fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - recipe.amax_compute_algo, - get_fp8_te_dtype(recipe, forward), - recipe.margin, - ) - else: - split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - - for amax_history, scale in zip( - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - ): - _amax_and_scale_update( - amax_history, scale, get_fp8_max(recipe, forward), recipe - ) - - @classmethod - def get_unique_autocast_key( - cls, - recipe: Optional[Recipe] = None, - group: Optional[dist_group_type] = None, - ): - """ - For FP8, each autocast can be uniquely identified by the recipe and fp8 group. - Safely using `hash` as we never cross checkpoint boundaries. - """ - return f"{str(recipe)}:{hash(group)}" - - @classmethod - def fp8_autocast_enter( - cls, - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, - ) -> None: - """Set state and tracking variables for entry into FP8 region.""" - - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - - cls.FP8_ENABLED = enabled - cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = fp8_recipe - cls.FP8_DISTRIBUTED_GROUP = fp8_group - cls.FP8_GRAPH_CAPTURING = _graph - - if cls.FP8_AUTOCAST_DEPTH == 0: - cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_DEPTH += 1 - - if enabled: - fp8_available, reason_for_no_fp8 = cls.is_fp8_available() - assert fp8_available, reason_for_no_fp8 - if isinstance(fp8_recipe, MXFP8BlockScaling): - mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() - assert mxfp8_available, reason_for_no_mxfp8 - if isinstance(fp8_recipe, Float8BlockScaling): - fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() - assert fp8_block_available, reason_for_no_fp8_block - if isinstance(fp8_recipe, NVFP4BlockScaling): - nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() - assert nvfp4_available, reason_for_no_nvfp4 - - @classmethod - def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: - """Set state and tracking variables for exit from FP8 region.""" - cls.FP8_AUTOCAST_DEPTH -= 1 - # Reduce only the non-FP8 weight modules here. - # FP8 weight modules are reduced at the end of the optimizer - # step after the weight amax is populated. - if enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - # delayed scaling only function, for other recipes (current scaling with any granularity), - # this is noop for other recipes because cls.global_amax_buffer is empty list - cls.reduce_and_update_fp8_tensors(forward=True) - - @classmethod - def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Copy the scaling factors and amaxes for recompute forward phase - to ensure both forward steps are numerically same. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - ] - - if buffer_position_key in fp8_meta: - cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) - else: - if len(cls.fp8_tensors_recompute_buffer) == 0: - cls.fp8_tensors_recompute_buffer = [deque()] - else: - cls.fp8_tensors_recompute_buffer.append(deque()) - cls.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 - - @classmethod - def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Switch to the copied scaling factors and amaxes from phase - 1 forward for indentical numerical outputs. - """ - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone() - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone() - - # Retrieve stashed amaxes and scales from phase 1 pre forward. - buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) - fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) - fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - - -@contextmanager -def fp8_model_init( - enabled: bool = True, - recipe: Optional[Recipe] = None, - preserve_high_precision_init_val: bool = False, -) -> None: - """ - Context manager for FP8 initialization of parameters. - - Example usage: - - .. code-block:: python - - with fp8_model_init(enabled=True): - model = transformer_engine.pytorch.Linear(768, 768) - - # Preserving high precision initial value to initialize master weight - with fp8_model_init(enabled=True, preserve_high_precision_init_val=True): - model = transformer_engine.pytorch.Linear(768, 768) - master_weight = model.weight.get_high_precision_init_val() - model.weight.clear_high_precision_init_val() - - Parameters - ---------- - enabled: bool, default = `True` - when enabled, Transformer Engine modules created inside this `fp8_model_init` - region will hold only FP8 copies of its parameters, as opposed to the default - behavior where both higher precision and FP8 copies are present. Setting this - option to `True` may result in lower memory consumption and is especially - useful for scenarios like: - - * full model training using optimizer with master weights, where the high - precision copies of weights are already present in the optimizer. - * inference, where only the FP8 copies of the parameters are used. - * LoRA-like fine-tuning, where the main parameters of the model do not change. - recipe: transformer_engine.common.recipe.Recipe, default = `None` - Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. - preserve_high_precision_init_val: bool, default = `False` - when enabled, store the high precision tensor used to initialize FP8 parameters - in CPU memory, and add two function attributes named `get_high_precision_init_val()` - and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high - precision tensor. The purpose is that users can use this high-precision copy - to initialize master weights, avoiding the loss of precision that can occur when - using FP8 parameters directly. Note that after the master weights are initialized, - users should call `clear_high_precision_init_val()` to release this CPU memory. - - This functionality is *EXPERIMENTAL*. - """ - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE - _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL - FP8GlobalStateManager.FP8_PARAMETERS = enabled - FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val - try: - yield - finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters - FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val - - -@contextmanager -def fp8_autocast( - enabled: bool = True, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, -) -> None: - """ - Context manager for FP8 usage. - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. - - Parameters - ---------- - enabled: bool, default = `True` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.Recipe, default = `None` - recipe used for FP8 training. - fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - if enabled: - check_recipe_support(fp8_recipe) - fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter( - enabled=enabled, - calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, - _graph=_graph, - ) - try: - yield - finally: - FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) - FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) - - -def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: - """Update amax history and set next amax to zero.""" - if amax_history.shape[0] > 1: - new_amax_history = torch.roll(amax_history, -1, 0) - amax_history.copy_(new_amax_history) - amax_history[0].fill_(0.0) - return amax_history - - -@torch.jit.script -def _default_get_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Default function to obtain amax from history.""" - if amax_compute_algo == "max": - amax = torch.max(amax_history, dim=0).values - else: # amax_compute_algo == "most_recent" - amax = amax_history[0].clone() - - amax_history = _update_amax_history(amax_history) - return amax_history, amax - - -@jit_fuser -def _default_sf_compute( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - margin: int, - _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter -) -> torch.Tensor: - """Default function to convert amax to scaling factor. - Computing the scaling factor requires consideration of the following scenarios: - 1. amax == 0: - No action is possible, set scale to the previous scale (or 1). - 2. 0 < amax < tiny_amax - The amax is too tiny that the scale becomes infinite in FP32. - Set scale = FP32_max - 3. tiny_amax <= amax < FP32_max: - Set scale = FP8_max (or scaled_max) / amax - 4. When amax == inf or amax == nan: - No action is possible, set scale to the previous scale (or 1). - """ - sf = (fp8_max / amax) / (2**margin) - sf = torch.where(amax > 0.0, sf, scale) - sf = torch.where(torch.isfinite(amax), sf, scale) - sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) - scale.copy_(sf) - return scale - - -def _compute_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: Union[Callable, str], -) -> Tuple[torch.Tensor, torch.Tensor]: - """Obtain the amax from the history.""" - - if callable(amax_compute_algo): - amax = amax_compute_algo(amax_history) - amax_history = _update_amax_history(amax_history) - return amax_history, amax - return _default_get_amax_and_update_history( - amax_history, - amax_compute_algo, - ) - - -def _compute_scaling_factor( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> torch.Tensor: - """Convert amax to scaling factor.""" - - if recipe.scaling_factor_compute_algo is None: - return _default_sf_compute( - amax, - scale, - fp8_max, - recipe.margin, - ) - return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) - - -def _amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> None: - """Updates FP8 meta tensors.""" - new_amax_history, amax = _compute_amax_and_update_history( - amax_history, - recipe.amax_compute_algo, - ) - new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) - scale.copy_(new_scale) - amax_history.copy_(new_amax_history) - - -def split_and_copy( - buffer: torch.Tensor, - outputs: List[torch.Tensor], - chunk_sizes: List[int], -) -> None: - """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" - splits = buffer.split(chunk_sizes) - torch._foreach_copy_(outputs, splits) - - -class RecipeState(abc.ABC): - """Configuration and state for a quantization recipe. - - This is a builder class for quantizers, which are in turn builder - classes for quantized tensors. - - This class may pack together the state for multiple quantizers, - which is helpful for applying fused kernels with less overhead. - - """ - - @staticmethod - def create( - recipe: Recipe, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> RecipeState: - """Factory method to create the state for a quantization recipe - - Parameters - ---------- - recipe: Recipe - Quantization recipe. - mode: {"forward", "backward"} - Training stage where quantization will be performed. - num_quantizers: int, default = 1 - Number of quantizers to create state for. - device: torch.device, default = default CUDA device - Device for quantized tensors. - - Returns - ------- - RecipeState: - Quantization recipe state. - - """ - - cls = None - if recipe.delayed(): - cls = DelayedScalingRecipeState - elif recipe.mxfp8(): - cls = MXFP8BlockScalingRecipeState - elif recipe.float8_current_scaling(): - cls = Float8CurrentScalingRecipeState - elif recipe.float8_block_scaling(): - cls = Float8BlockScalingRecipeState - elif recipe.nvfp4(): - cls = NVFP4BlockScalingRecipeState - elif recipe.custom(): - cls = CustomRecipeState - else: - raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( - recipe, - mode=mode, - num_quantizers=num_quantizers, - device=device, - ) - - @abc.abstractmethod - def make_quantizers(self) -> list: - """Convert recipe state to quantizers. - - Quantizers are builder classes for quantized tensors. They are - typically used to convert a high-precision tensor (e.g. in - FP32 or BF16) into a quantized tensor (e.g. in FP8). - - """ - - -class DelayedScalingRecipeState(RecipeState): - """State for FP8 quantization with per-tensor delayed scaling. - - Delayed scaling recipe requires a scaling factor (applied when - casting to FP8) and a history of max-abs values ("amax") from - recent FP8 casts for updating the scaling factor. The scale update - is handled externally by `FP8GlobalStateManager`. - - """ - - recipe: DelayedScaling - mode: str - dtype: tex.DType - scale: torch.Tensor - amax_history: torch.Tensor - - def __init__( - self, - recipe: DelayedScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) - self.amax_history = torch.zeros( - recipe.amax_history_len, - num_quantizers, - dtype=torch.float32, - device=device, - ) - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_tensor import Float8Quantizer - - return [ - Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) - for i in range(self.num_quantizers) - ] - - -class Float8CurrentScalingRecipeState(RecipeState): - """Configuration for Per-tensor current scaling quantization. - - Per-tensor current quantization does not require state. - - """ - - recipe: Float8CurrentScaling - mode: str - dtype: tex.DType - device: torch.device - - def __init__( - self, - recipe: Float8CurrentScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - from .tensor.float8_tensor import Float8CurrentScalingQuantizer - - return [ - Float8CurrentScalingQuantizer( - self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales - ) - for i in range(self.num_quantizers) - ] - - -class MXFP8BlockScalingRecipeState(RecipeState): - """Configuration for MXFP8 quantization. - - MXFP8 quantization does not require state. - - """ - - recipe: MXFP8BlockScaling - mode: str - dtype: tex.DType - - def __init__( - self, - recipe: MXFP8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.mxfp8_tensor import MXFP8Quantizer - - return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] - - -class Float8BlockScalingRecipeState(RecipeState): - """Configuration for Float8BlockScaling quantization. - - Float8BlockScaling quantization does not require state, - but different quantizers use different modes. - """ - - recipe: Float8BlockScaling - mode: str - qx_dtype: tex.DType - qw_dtype: tex.DType - qgrad_dtype: tex.DType - - def __init__( - self, - recipe: Float8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.qx_dtype = get_fp8_te_dtype(recipe, True) - self.qw_dtype = get_fp8_te_dtype(recipe, True) - self.qgrad_dtype = get_fp8_te_dtype(recipe, False) - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_blockwise_tensor import Float8BlockQuantizer - - if self.mode == "forward": - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward, and doesn't play nicely with QuantizeOp, - # which is not associated with a GEMM. - assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qw_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 3) - ] - ) - ) - - assert self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 2) - ] - ) - ) - - -class NVFP4BlockScalingRecipeState(RecipeState): - """Configuration for NVFP4 quantization. - - NVFP4 quantization does not require state. - - """ - - recipe: NVFP4BlockScaling - mode: str - dtype: tex.DType - - def __init__( - self, - recipe: NVFP4BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp4_te_dtype(recipe) - - # Allocate buffers - if device is None: - device = torch.device("cuda") - - def make_quantizers(self) -> list: - from .tensor.nvfp4_tensor import NVFP4Quantizer - - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward. It assumes forward quantizers are - # ordered [input, weight, output, ...] and backward quantizers - # are ordered [grad_output, grad_input, ...]. This doesn't - # play nicely with fusible ops: Linear op doesn't own output - # or grad input quantizers, Quantize op only owns input and - # grad output quantizers. - - if self.mode == "forward": - - def _make_quantizer(idx: int) -> NVFP4Quantizer: - qparams = ( - self.recipe.fp4_quant_fwd_weight - if idx % 3 == 1 - else self.recipe.fp4_quant_fwd_inp - ) - return NVFP4Quantizer( - fp4_dtype=self.dtype, - rowwise=True, - columnwise=True, - with_rht=qparams.random_hadamard_transform, - with_post_rht_amax=qparams.random_hadamard_transform, - with_2d_quantization=qparams.fp4_2d_quantization, - stochastic_rounding=qparams.stochastic_rounding, - ) - - return [_make_quantizer(idx) for idx in range(self.num_quantizers)] - - if self.mode == "backward": - return [ - NVFP4Quantizer( - fp4_dtype=self.dtype, - rowwise=True, - columnwise=True, - with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, - with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, - with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, - stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, - ) - for _ in range(self.num_quantizers) - ] - - raise RuntimeError(f"Unexpected recipe mode ({self.mode})") - - -class CustomRecipeState(RecipeState): - """State for CustomRecipe: produce quantizers per tensor.""" - - recipe: CustomRecipe - mode: str - num_quantizers: int - device: Optional[torch.device] - - def __init__( - self, - recipe: CustomRecipe, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - if device is None: - device = torch.device("cuda") - self.device = device - - if getattr(recipe, "qfactory", None) is None: - raise ValueError("CustomRecipe requires `qfactory`.") - - def make_quantizers(self) -> list: - qfactory = self.recipe.qfactory - out = [] - - # TODO(negvet): make_quantizers() should take roles from the operation - # Hardcode linear-specific roles for now - roles: List[str] - if self.mode == "forward": - roles = [ - ("linear_input", "linear_weight", "linear_output")[i % 3] - for i in range(self.num_quantizers) - ] - elif self.mode == "backward": - roles = [ - ("linear_grad_output", "linear_grad_input")[i % 2] - for i in range(self.num_quantizers) - ] - else: - roles = ["unknown"] * self.num_quantizers - - for i in range(self.num_quantizers): - # Get quantizer from the user defined factory - quantizer = qfactory(roles[i]) - out.append(quantizer) - return out +# Importing each function instead of 'import *' allows us specify '__all__' in +# quantize.py and also makes any newer additions to quantize.py invisible via +# fp8.py so that we don't reinforce importing internal TE functions. +from .quantization import ( + check_fp8_support, + check_mxfp8_support, + check_nvfp4_support, + check_fp8_block_scaling_support, + check_recipe_support, + get_default_fp8_recipe, + get_fp8_torch_dtype, + get_fp8_te_dtype, + get_fp4_te_dtype, + get_fp8_max, + FP8GlobalStateManager, + fp8_model_init, + fp8_autocast, + _update_amax_history, + _default_get_amax_and_update_history, + _default_sf_compute, + _compute_amax_and_update_history, + _compute_scaling_factor, + _amax_and_scale_update, + split_and_copy, + RecipeState, + DelayedScalingRecipeState, + Float8CurrentScalingRecipeState, + MXFP8BlockScalingRecipeState, + Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, + CustomRecipeState, +) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f0fe557c0..798d3209a 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -6,6 +6,7 @@ from collections.abc import Iterable import contextlib import gc +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -15,8 +16,8 @@ from transformer_engine.common.recipe import DelayedScaling, Recipe from transformer_engine.pytorch.constants import dist_group_type -from .fp8 import ( - fp8_autocast, +from .quantization import ( + autocast, FP8GlobalStateManager, get_default_fp8_recipe, ) @@ -84,7 +85,7 @@ def _make_graphed_callables( sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], num_warmup_iters: int = 3, allow_unused_input: bool = False, - fp8_weight_caching: bool = False, + cache_quantized_params: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, _num_layers_per_chunk: Optional[List[int]] = None, @@ -252,7 +253,7 @@ def _make_graphed_callables( consumed_sample_q[sample_keys].append(per_callable_fwd_idx) fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] - if fp8_weight_caching: + if cache_quantized_params: # Initialize flag that controls FP8 weight updates FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) @@ -687,7 +688,7 @@ def functionalized(*user_args, **user_kwargs): # Decide whether to update FP8 weights skip_fp8_weight_update = None - if fp8_weight_caching: + if cache_quantized_params: assert "is_first_microbatch" in user_kwargs and isinstance( user_kwargs["is_first_microbatch"], bool ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." @@ -796,14 +797,14 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: Optional[Recipe], + recipe: Optional[Recipe], ) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ - if not isinstance(fp8_recipe, DelayedScaling): + if not isinstance(recipe, DelayedScaling): return None fp8_tensors = [] @@ -812,10 +813,10 @@ def save_fp8_tensors( module_tensors = None if isinstance(m, TransformerEngineBaseModule): if m.primary_weights_in_fp8: - m.adjust_amax_history_length(fp8_recipe.amax_history_len) + m.adjust_amax_history_length(recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() elif isinstance(m, BasicOperation): - m.reset_recipe_state(recipe=fp8_recipe) + m.reset_recipe_state(recipe=recipe) module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors @@ -850,11 +851,16 @@ def make_graphed_callables( num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, - fp8_enabled: SingleOrTuple[bool] = False, - fp8_calibrating: bool = False, + fp8_enabled: Optional[SingleOrTuple[bool]] = None, + fp8_calibrating: Optional[bool] = None, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, - fp8_weight_caching: bool = False, + fp8_weight_caching: Optional[bool] = None, + enabled: Optional[SingleOrTuple[bool]] = None, + calibrating: Optional[bool] = None, + recipe: Optional[Recipe] = None, + amax_reduction_group: Optional[dist_group_type] = None, + cache_quantized_params: Optional[bool] = None, _order: Optional[List[int]] = None, _num_layers_per_chunk: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, @@ -870,6 +876,11 @@ def make_graphed_callables( `original PyTorch implementation `_ for more documentation. + .. warning:: + + Arguments 'fp8_enabled', 'fp8_calibrating', 'fp8_recipe', 'fp8_group', and 'fp8_weight_caching' are deprecated. + Use arguments 'enabled', 'calibrating', 'recipe', 'amax_reduction_group', and 'cache_quantized_params' instead. + Graphing parameters ------------------- modules: (tuple of) callable @@ -894,30 +905,110 @@ def make_graphed_callables( when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. - FP8-related parameters + Quantization related parameters ---------------------- - fp8_enabled: (tuple of) bool, default = `False` - whether or not to enable fp8. - If tuple, the length must match the number of modules. - fp8_calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: Recipe, default = `None` - recipe used for FP8 training. - fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - fp8_weight_caching: bool, default = `False` - Whether or not to cache FP8 weights across microbatches. if set to `True`, - the `is_first_microbatch` boolean argument must be passed into the forward - method for TransformerEngine modules. When storing primary weights in FP8 - using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg - must be set to `False` if calculating weight transposes' outside TE, e.g., - in the optimizer step. + enabled: (tuple of) bool, default = `False` + whether or not to enable low precision quantization (FP8/FP4). + If tuple, the length must match the number of modules. + calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of quantized tensors even when executing without quantization enabled. + This is useful for saving an inference ready checkpoint while training + using a higher precision. + recipe: recipe.Recipe, default = `None` + recipe used for low precision quantization. + amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the quantized tensors + are reduced at the end of each training step. + cache_quantized_params: bool, default = `False` + Whether or not to cache quantized weights across microbatches. if set to `True`, + the `is_first_microbatch` boolean argument must be passed into the forward + method for TransformerEngine modules. When storing primary weights in low precision + using TE's `quantized_model_init` API and using an quantization aware optimizer, + this arg must be set to `False` if calculating weight transposes' outside TE, e.g., + in the optimizer step. """ + + # Handle deprecated args. If old kwargs are set, they are prioritized with warning. + if fp8_enabled is not None: + if enabled is not None: + raise ValueError( + "make_graphed_callables has deprecated `fp8_enabled` kwarg " + "in favor of `enabled`, but both kwargs are set." + ) + warnings.warn( + "make_graphed_callables has deprecated `fp8_enabled` kwarg in favor of `enabled`. " + "`fp8_enabled` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + enabled = fp8_enabled + if enabled is None: + enabled = False + + if fp8_calibrating is not None: + if calibrating is not None: + raise ValueError( + "make_graphed_callables has deprecated `fp8_calibrating` kwarg " + "in favor of `calibrating`, but both kwargs are set." + ) + warnings.warn( + "make_graphed_callables has deprecated `fp8_calibrating` kwarg in favor of " + "`calibrating`. `fp8_calibrating` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + calibrating = fp8_calibrating + if calibrating is None: + calibrating = False + + if fp8_recipe is not None: + if recipe is None: + warnings.warn( + "make_graphed_callables has deprecated `fp8_recipe` kwarg in favor of " + "`recipe`. `fp8_recipe` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + raise ValueError( + "make_graphed_callables has deprecated `fp8_recipe` kwarg " + "in favor of `recipe`, but both kwargs are set." + ) + recipe = fp8_recipe + + if fp8_group is not None: + if amax_reduction_group is None: + warnings.warn( + "make_graphed_callables has deprecated `fp8_group` kwarg in favor of " + "`amax_reduction_group`. `fp8_group` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + raise ValueError( + "make_graphed_callables has deprecated `fp8_group` kwarg " + "in favor of `amax_reduction_group`, but both kwargs are set." + ) + amax_reduction_group = fp8_group + + if fp8_weight_caching is not None: + if cache_quantized_params is not None: + raise ValueError( + "make_graphed_callables has deprecated `fp8_weight_caching` kwarg " + "in favor of `cache_quantized_params`, but both kwargs are set." + ) + warnings.warn( + "make_graphed_callables has deprecated `fp8_weight_caching` kwarg in favor of " + "`cache_quantized_params`. `fp8_weight_caching` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + cache_quantized_params = fp8_weight_caching + if cache_quantized_params is None: + cache_quantized_params = False + set_capture_start() # Handle single module. @@ -926,21 +1017,21 @@ def make_graphed_callables( just_one_callable = True modules = (modules,) - if not isinstance(fp8_enabled, tuple): - assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools" - fp8_enabled = (fp8_enabled,) * len(modules) + if not isinstance(enabled, tuple): + assert isinstance(enabled, bool), "enabled must be a bool or a tuple of bools" + enabled = (enabled,) * len(modules) else: - assert len(fp8_enabled) == len( + assert len(enabled) == len( modules - ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})" - if any(fp8_enabled) and fp8_recipe is None: - fp8_recipe = get_default_fp8_recipe() - elif not any(fp8_enabled): - fp8_recipe = None - module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled)) + ), f"enabled length ({len(enabled)}) must match modules length ({len(modules)})" + if any(enabled) and recipe is None: + recipe = get_default_fp8_recipe() + elif not any(enabled): + recipe = None + module_uses_fp8 = dict(zip((id(m) for m in modules), enabled)) # Store FP8 tensors to reset later. - saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) + saved_fp8_tensors = save_fp8_tensors(modules, recipe=recipe) # FP8 wrapper. old_call_funcs = {} @@ -954,11 +1045,11 @@ def wrap_autocast(block): # Wrap the original call function of the module class. def call_func(self, *args, **kwargs): - with fp8_autocast( + with autocast( enabled=module_uses_fp8.get(id(self), False), - calibrating=fp8_calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, + calibrating=calibrating, + recipe=recipe, + amax_reduction_group=amax_reduction_group, _graph=True, ): outputs = old_call_funcs[block_cls](self, *args, **kwargs) @@ -992,7 +1083,7 @@ def call_func(self, *args, **kwargs): sample_args, num_warmup_iters=num_warmup_iters, allow_unused_input=allow_unused_input, - fp8_weight_caching=fp8_weight_caching, + cache_quantized_params=cache_quantized_params, sample_kwargs=sample_kwargs, _order=_order, _num_layers_per_chunk=_num_layers_per_chunk, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 838ac5281..d16455b5b 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -22,7 +22,7 @@ from transformer_engine.common.recipe import Recipe from ._common import _ParameterInitMeta, noop_cat -from ..fp8 import ( +from ..quantization import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, @@ -1574,8 +1574,8 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: - MXFP8BlockScaling → MXFP8Tensor - Float8BlockScaling → Float8BlockTensor - Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), - but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). + Example case to check: recipe is DelayedScaling (DelayedScaling is set in autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in quantized_model_init()). """ if not self.fp8 and not self.fp8_calibration: return @@ -1596,6 +1596,6 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: raise RuntimeError( f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." - " Please check the recipes assigned during fp8_model_init() and" - " fp8_autocast() calls." + " Please check the recipes assigned during quantized_model_init() and" + " autocast() calls." ) diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 0d2e3e6d7..5d569d59d 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..jit import no_torch_dynamo diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 3b0f8928f..b74395dd8 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..jit import no_torch_dynamo diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index ec05f684b..53a32818b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -20,7 +20,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 824fcc0a7..05f2e9cde 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -27,7 +27,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, assert_dim_for_all_gather, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 8ef19d052..a2ddb970a 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -28,7 +28,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 12b7bac01..3069c21d9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -26,7 +26,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..utils import ( cast_if_needed, clear_tensor_data, diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 13db35fc7..52ca84b5d 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -11,7 +11,7 @@ from transformer_engine_torch import FP8TensorMeta from .. import torch_version -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor from ..tensor.quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index b15d840d6..432d8c134 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,7 +19,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...fp8 import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, Recipe from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -303,8 +303,8 @@ def reset_parameters(self) -> None: "Tried to quantize weight with deferred initialization " "due to meta device, but no quantizer was available. " "This is most likely because the weight was initialized " - "within fp8_model_init, but the forward pass was not " - "performed within fp8_autocast." + "within quantized_model_init, but the forward pass was not " + "performed within autocast." ) quantizer.set_usage( rowwise=True, diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index dcfc3c4f7..87c65d4b2 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -9,7 +9,7 @@ import torch -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from .._common import is_quantized_tensor from ..op import BasicOperation, OperationContext from ...tensor import Quantizer @@ -18,8 +18,8 @@ class Quantize(BasicOperation): """Quantize tensor data - Uses FP8 recipe from `fp8_autocast` context. When called outside - of an `fp8_autocast` context, this is an identity operation. + Uses recipe from `autocast` context. When called outside + of an `autocast` context, this is an identity operation. Parameters ---------- diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 40510c856..7897ef164 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import Recipe +from transformer_engine.pytorch.quantization import Recipe from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index ab271e17b..74bd3d1b3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -11,7 +11,7 @@ import torch from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import BasicLinear, Bias from ..op import FusedOperation, FusibleOperation, OperationContext diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 4831ae407..6d5d55339 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -11,7 +11,7 @@ import torch from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias from ..op import FusedOperation, FusibleOperation, OperationContext diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 72e17f64e..24788bcdf 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -11,7 +11,7 @@ import torch from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..op import ( diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index cbbe529d6..e20de53da 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -14,7 +14,7 @@ from ...cpp_extensions import general_gemm from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...module.base import ( fill_userbuffers_buffer_for_all_gather, get_ub, diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 6f80a7a1f..8ae112022 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -11,7 +11,7 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 095e3e89e..639817ada 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,10 +14,10 @@ import torch from transformer_engine.common.recipe import Recipe -from ..fp8 import ( +from ..quantization import ( FP8GlobalStateManager, RecipeState, - fp8_autocast, + autocast, ) from ..tensor import Quantizer @@ -634,7 +634,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: # Get op's quantizer state, initializing if needed if self._fp8_metas is None or self._fp8_metas[mode] is None: - with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + with autocast(recipe=state[mode]["recipe"]): self.reset_recipe_state(recipe=state[mode]["recipe"]) fp8_meta = self._fp8_metas[mode] diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py new file mode 100644 index 000000000..030370b9d --- /dev/null +++ b/transformer_engine/pytorch/quantization.py @@ -0,0 +1,1396 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization utilities for TransformerEngine""" +from __future__ import annotations + +import abc +import itertools +import functools +import warnings +import os +from contextlib import contextmanager +from collections import deque +from typing import Callable, List, Optional, Dict, Any, Tuple, Union + +import torch +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, + Float8BlockScaling, + NVFP4BlockScaling, + CustomRecipe, +) + +from .constants import dist_group_type +from .utils import get_device_compute_capability +from .jit import jit_fuser + + +__all__ = [ + "autocast", + "quantized_model_init", + "is_fp8_available", + "is_mxfp8_available", + "is_fp8_block_scaling_available", + "is_nvfp4_available", + "get_default_recipe", +] + + +@functools.lru_cache(maxsize=None) +def check_fp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if get_device_compute_capability() >= (9, 0): # hopper and above + return True, "" + if get_device_compute_capability() < (8, 9): # pre-ada + return False, "Device compute capability 8.9 or higher required for FP8 execution." + if tex.get_cublasLt_version() < 120103: + return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." + if float(torch.version.cuda) < 12.1: + return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." + return True, "" + + +@functools.lru_cache(maxsize=None) +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if get_device_compute_capability() >= (12, 0): + return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for MXFP8 execution." + + +@functools.lru_cache(maxsize=None) +def check_nvfp4_support() -> Tuple[bool, str]: + """Return if nvfp4 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." + + +@functools.lru_cache(maxsize=None) +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: + return True, "" + return ( + False, + "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", + ) + + +def check_recipe_support(recipe: Recipe) -> None: + """Check if the given recipe is supported.""" + recipe_supported = True + unsupported_reason = "" + if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): + recipe_supported, unsupported_reason = check_fp8_support() + elif isinstance(recipe, Float8BlockScaling): + recipe_supported, unsupported_reason = check_fp8_block_scaling_support() + elif isinstance(recipe, MXFP8BlockScaling): + recipe_supported, unsupported_reason = check_mxfp8_support() + assert recipe_supported, unsupported_reason + + +def get_default_fp8_recipe() -> Recipe: + """FP8 recipe with default args.""" + if check_mxfp8_support()[0]: + return MXFP8BlockScaling() + if get_device_compute_capability() >= (12, 0): + # This is a temporary restriction until MXFP8 is supported for all gemm layouts. + return Float8CurrentScaling() + return DelayedScaling() + + +def get_default_recipe() -> Recipe: + """Returns the default training recipe based on available device.""" + return get_default_fp8_recipe() + + +def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: + """Get fp8 data type according to recipe and tensor""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return torch.float8_e4m3fn + return torch.float8_e5m2 + + +def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: + """Get fp8 data type according to recipe and tensor""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return tex.DType.kFloat8E4M3 + return tex.DType.kFloat8E5M2 + + +def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: + """Get fp4 data type according to recipe and tensor""" + if fp4_recipe.fp4_format == Format.E2M1: + return tex.DType.kFloat4E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: + """Get max representible FP8 value.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return Format.E4M3.value.max_fwd + return Format.E5M2.value.max_fwd + + +def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if FP8 support is available for the delayed + scaling and per tensor current scaling recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for FP8 is available. + + """ + if return_reason: + return check_fp8_support() + return check_fp8_support()[0] + + +def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for the MXFP8 recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for MXFP8 is available. + + """ + if return_reason: + return check_mxfp8_support() + return check_mxfp8_support()[0] + + +def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for the FP8 block scaling recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for FP8 block scaling is available. + + """ + if return_reason: + return check_fp8_block_scaling_support() + return check_fp8_block_scaling_support()[0] + + +def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for the NVFP4 recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for NVFP4 is available. + + """ + if return_reason: + return check_nvfp4_support() + return check_nvfp4_support()[0] + + +class FP8GlobalStateManager: + """Class to keep track of and manipulate the global + FP8 state at different stages of execution. + """ + + FP8_ENABLED = False + FP8_CALIBRATION = False + FP8_RECIPE = None + FP8_DISTRIBUTED_GROUP = None + FP8_PARAMETERS = False + HIGH_PRECISION_INIT_VAL = False + IS_FIRST_FP8_MODULE = False + FP8_GRAPH_CAPTURING = False + AUTOCAST_DEPTH = 0 + global_amax_buffer = {} + global_amax_history_buffer = {} + global_scale_buffer = {} + fp8_tensors_recompute_buffer = [] + fp8_available = None + reason_for_no_fp8 = "" + autocast_arguments = {} + skip_fp8_weight_update_tensor = None + mxfp8_available = None + reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None + nvfp4_available = None + reason_for_no_nvfp4 = "" + + @classmethod + def reset(cls) -> None: + """Reset the global state""" + cls.FP8_ENABLED = False + cls.FP8_CALIBRATION = False + cls.FP8_RECIPE = None + cls.FP8_DISTRIBUTED_GROUP = None + cls.FP8_PARAMETERS = False + cls.HIGH_PRECISION_INIT_VAL = False + cls.IS_FIRST_FP8_MODULE = False + cls.FP8_GRAPH_CAPTURING = False + cls.AUTOCAST_DEPTH = 0 + cls.global_amax_buffer = {} + cls.global_amax_history_buffer = {} + cls.global_scale_buffer = {} + cls.fp8_tensors_recompute_buffer = [] + cls.fp8_available = None + cls.reason_for_no_fp8 = "" + cls.autocast_arguments = {} + cls.skip_fp8_weight_update_tensor = None + cls.mxfp8_available = None + cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" + + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """`skip_fp8_weight_update_tensor` inplace setter.""" + if cls.skip_fp8_weight_update_tensor is None: + cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + cls.skip_fp8_weight_update_tensor.fill_(skip) + + @classmethod + def get_skip_fp8_weight_update_tensor(cls) -> None: + """`skip_fp8_weight_update_tensor` getter.""" + return cls.skip_fp8_weight_update_tensor + + @classmethod + def is_fp8_available(cls) -> Tuple[bool, str]: + """Return if fp8 support is available""" + return check_fp8_support() + + @classmethod + def is_mxfp8_available(cls) -> Tuple[bool, str]: + """Return if MXFP8/current scaling support is available.""" + return check_mxfp8_support() + + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + return check_fp8_block_scaling_support() + + @classmethod + def is_nvfp4_available(cls) -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + return check_nvfp4_support() + + @staticmethod + def get_meta_tensor_key(forward: bool = True) -> str: + """Returns scaling key in `fp8_meta`.""" + if forward: + return "scaling_fwd" + return "scaling_bwd" + + @staticmethod + def get_fwd_bwd_key(forward: bool = True) -> str: + """Convert bool `forward` to string.""" + return "forward" if forward else "backward" + + @classmethod + def get_buffer_info(cls) -> str: + """ + Returns a key for `fp8_meta` that stores the module's index + in the global buffers along with autocast information. + """ + return "buffer_index_and_autocast_key" + + @classmethod + def get_key_in_buffer( + cls, + forward: bool, + fp8_recipe: Recipe, + fp8_group: dist_group_type, + ) -> str: + """Returns a key into the global FP8 buffers.""" + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + return f"{fwd_bwd_key}_{autocast_key}" + + @classmethod + def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: + """Splits buffer key into relevant parts.""" + forward, autocast_key = key.split("_", 1) + forward = forward == "forward" + return forward, autocast_key + + @classmethod + def add_fp8_tensors_to_global_buffer( + cls, + fp8_meta: Dict[str, Any], + ) -> None: + """ + Delayed scaling only. + + The amax reduction process happens completely outside the FP8 modules. + To participate in the reduction, the only role played by a module is + to call this function in order to append it's FP8 tensor into a global + buffer. There are 5 global buffers maintained, one each for amax, amax + history, scale, scale-inverse, and non-weight-mask. Each buffer has + keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix + to indicate the type of FP8 tensor, since the forward and backward + reductions happen separately. + + Note: For CG capture, this method is called from the graphed + wrapper. For non CG case, it's called from within the module. + """ + + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + # Every module must call this function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. + if index_in_buffer in fp8_meta: + return + + fp8_meta[index_in_buffer] = [] + for forward in (True, False): + fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) + if fp8_meta_tensor_key not in fp8_meta: + # Handles non-parameter FP8 modules, e.g. DPA. + continue + + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) + + if key not in cls.global_amax_buffer: + cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] + cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + else: + cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_history_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history + ) + cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + fp8_meta[index_in_buffer].append(key) + + @classmethod + def is_fp8_enabled(cls) -> bool: + """Is FP8 enabled""" + return cls.FP8_ENABLED + + @classmethod + def is_fp8_calibration(cls) -> bool: + """Is FP8 calibration""" + return cls.FP8_CALIBRATION + + @classmethod + def with_fp8_parameters(cls) -> bool: + """Should the parameters be stored as FP8""" + return cls.FP8_PARAMETERS + + @classmethod + def with_high_precision_init_val(cls) -> bool: + """Should the high precision initial values be stored with FP8 parameters""" + return cls.HIGH_PRECISION_INIT_VAL + + @classmethod + def fp8_graph_capturing(cls) -> bool: + """Is CUDA graph capture under way?""" + return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + + @classmethod + def is_first_fp8_module(cls): + """Returns `True` only the first time when called multiple + times from within the same `autocast` context. + """ + tmp = cls.IS_FIRST_FP8_MODULE + cls.IS_FIRST_FP8_MODULE = False + return tmp + + @classmethod + def get_fp8_recipe(cls) -> Recipe: + """Return the fp8 recipe""" + if cls.FP8_RECIPE is not None: + return cls.FP8_RECIPE + return get_default_fp8_recipe() + + @classmethod + def get_fp8_group(cls) -> Union[dist_group_type, None]: + """Return the fp8 group for scale/amax comm""" + return cls.FP8_DISTRIBUTED_GROUP + + @classmethod + def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: + """FP8 autocast state getter""" + return ( + cls.FP8_ENABLED, + cls.FP8_CALIBRATION, + cls.FP8_RECIPE, + cls.FP8_DISTRIBUTED_GROUP, + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING, + ) + + @classmethod + def set_autocast_state( + cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] + ) -> None: + """FP8 autocast state setter""" + ( + cls.FP8_ENABLED, + cls.FP8_CALIBRATION, + cls.FP8_RECIPE, + cls.FP8_DISTRIBUTED_GROUP, + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING, + ) = fp8_state + + @staticmethod + def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: + """Reduce tensor across given group.""" + if torch.distributed.is_initialized(): + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.MAX, + group=group, + async_op=False, + ) + + @classmethod + def reduce_and_update_fp8_tensors( + cls, + forward: bool = True, + ) -> None: + """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" + # global_amax_buffer should only be non-empty for fp8 delayed scaling + for buffer_key, amax_buffer in cls.global_amax_buffer.items(): + # Check for forward or backward reduction. + fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) + if fwd_update != forward: + continue + if len(amax_buffer) == 0: + continue + + # Retrieve autocast specific args and concat amaxes. + recipe, group = cls.autocast_arguments[autocast_key] + contiguous_amax = torch.cat(amax_buffer) + + # Reduction. + if ( + recipe.reduce_amax + and torch.distributed.is_initialized() + and torch.distributed.get_world_size(group=group) > 1 + ): + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) + + # Amax and scale update. + unfused_update = ( + bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) + or callable(recipe.amax_compute_algo) + or callable(recipe.scaling_factor_compute_algo) + ) + + if not unfused_update: + tex.fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + recipe.amax_compute_algo, + get_fp8_te_dtype(recipe, forward), + recipe.margin, + ) + else: + split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) + + for amax_history, scale in zip( + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + ): + _amax_and_scale_update( + amax_history, scale, get_fp8_max(recipe, forward), recipe + ) + + @classmethod + def get_unique_autocast_key( + cls, + recipe: Optional[Recipe] = None, + group: Optional[dist_group_type] = None, + ): + """ + For FP8, each autocast can be uniquely identified by the recipe and fp8 group. + Safely using `hash` as we never cross checkpoint boundaries. + """ + return f"{str(recipe)}:{hash(group)}" + + @classmethod + def autocast_enter( + cls, + enabled: bool = False, + calibrating: bool = False, + fp8_recipe: Optional[Recipe] = None, + fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, + ) -> None: + """Set state and tracking variables for entry into FP8 region.""" + + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + + cls.FP8_ENABLED = enabled + cls.FP8_CALIBRATION = calibrating + cls.FP8_RECIPE = fp8_recipe + cls.FP8_DISTRIBUTED_GROUP = fp8_group + cls.FP8_GRAPH_CAPTURING = _graph + + if cls.AUTOCAST_DEPTH == 0: + cls.IS_FIRST_FP8_MODULE = True + cls.AUTOCAST_DEPTH += 1 + + if enabled: + fp8_available, reason_for_no_fp8 = cls.is_fp8_available() + assert fp8_available, reason_for_no_fp8 + if isinstance(fp8_recipe, MXFP8BlockScaling): + mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() + assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block + if isinstance(fp8_recipe, NVFP4BlockScaling): + nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() + assert nvfp4_available, reason_for_no_nvfp4 + + @classmethod + def autocast_exit(cls, enabled: bool, _graph: bool) -> None: + """Set state and tracking variables for exit from FP8 region.""" + cls.AUTOCAST_DEPTH -= 1 + # Reduce only the non-FP8 weight modules here. + # FP8 weight modules are reduced at the end of the optimizer + # step after the weight amax is populated. + if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): + # delayed scaling only function, for other recipes (current scaling with any granularity), + # this is noop for other recipes because cls.global_amax_buffer is empty list + cls.reduce_and_update_fp8_tensors(forward=True) + + @classmethod + def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: + """Copy the scaling factors and amaxes for recompute forward phase + to ensure both forward steps are numerically same. + """ + + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" + + to_copy = [ + fp8_meta["scaling_fwd"].amax_history.clone(), + fp8_meta["scaling_fwd"].scale.clone(), + ] + + if buffer_position_key in fp8_meta: + cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) + else: + if len(cls.fp8_tensors_recompute_buffer) == 0: + cls.fp8_tensors_recompute_buffer = [deque()] + else: + cls.fp8_tensors_recompute_buffer.append(deque()) + cls.fp8_tensors_recompute_buffer[-1].append(to_copy) + fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 + + @classmethod + def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: + """Switch to the copied scaling factors and amaxes from phase + 1 forward for indentical numerical outputs. + """ + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + # Store updated amaxes and scales from phase 1 post forward. + fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone() + fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone() + + # Retrieve stashed amaxes and scales from phase 1 pre forward. + buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" + stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() + + # Replace amaxes and scales with stashed values for phase 2 forward + fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) + fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) + + @staticmethod + def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: + """Restore latest scaling factors and amaxes after recompute forward run.""" + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) + fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) + + +@contextmanager +def fp8_model_init( + enabled: bool = True, + recipe: Optional[Recipe] = None, + preserve_high_precision_init_val: bool = False, +) -> None: + """ + .. warning:: + + fp8_model_init is deprecated and will be removed in a future release. Use + quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead. + + """ + + warnings.warn( + "fp8_model_init is deprecated and will be removed in a future release. " + "Use quantized_model_init(" + "enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + # Call new implementation. + with quantized_model_init( + enabled=enabled, + recipe=recipe, + preserve_high_precision_init_val=preserve_high_precision_init_val, + ): + yield + + +@contextmanager +def quantized_model_init( + enabled: bool = True, + recipe: Optional[Recipe] = None, + preserve_high_precision_init_val: bool = False, +) -> None: + """ + Context manager for initialization of quantized parameters. + + Example usage: + + .. code-block:: python + + with quantized_model_init(enabled=True): + model = transformer_engine.pytorch.Linear(768, 768) + + # Preserving high precision initial value to initialize master weight + with quantized_model_init(enabled=True, preserve_high_precision_init_val=True): + model = transformer_engine.pytorch.Linear(768, 768) + master_weight = model.weight.get_high_precision_init_val() + model.weight.clear_high_precision_init_val() + + Parameters + ---------- + enabled: bool, default = `True` + when enabled, Transformer Engine modules created inside this `quantized_model_init` + region will hold only quantized copies of its parameters, as opposed to the default + behavior where both higher precision and quantized copies are present. Setting this + option to `True` may result in lower memory consumption and is especially + useful for scenarios like: + + * full model training using optimizer with master weights, where the high + precision copies of weights are already present in the optimizer. + * inference, where only the quantized copies of the parameters are used. + * LoRA-like fine-tuning, where the main parameters of the model do not change. + recipe: transformer_engine.common.recipe.Recipe, default = `None` + Recipe used to create the parameters. If left to None, it uses the default recipe. + preserve_high_precision_init_val: bool, default = `False` + when enabled, store the high precision tensor used to initialize quantized parameters + in CPU memory, and add two function attributes named `get_high_precision_init_val()` + and `clear_high_precision_init_val()` to quantized parameters to get/clear this high + precision tensor. The purpose is that users can use this high-precision copy + to initialize master weights, avoiding the loss of precision that can occur when + using quantized parameters directly. Note that after the master weights are initialized, + users should call `clear_high_precision_init_val()` to release this CPU memory. + + This functionality is *EXPERIMENTAL*. + """ + + _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE + _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL + FP8GlobalStateManager.FP8_PARAMETERS = enabled + FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val + try: + yield + finally: + FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val + + +@contextmanager +def fp8_autocast( + enabled: bool = True, + calibrating: bool = False, + fp8_recipe: Optional[Recipe] = None, + fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, +) -> None: + """ + .. warning:: + + fp8_autocast is deprecated and will be removed in a future release. + Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead. + + """ + + warnings.warn( + "fp8_autocast is deprecated and will be removed in a future release. " + "Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + # Call new implementation. + with autocast( + enabled=enabled, + calibrating=calibrating, + recipe=fp8_recipe, + amax_reduction_group=fp8_group, + _graph=_graph, + ): + yield + + +@contextmanager +def autocast( + enabled: bool = True, + calibrating: bool = False, + recipe: Optional["Recipe"] = None, + amax_reduction_group: Optional["dist_group_type"] = None, + _graph: bool = False, +) -> None: + """ + Context manager for quantization schemes like FP8 or FP4. + + .. code-block:: python + + with autocast(enabled=True): + out = model(inp) + + .. note:: + + Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors + with shapes where both dimensions are divisible by 16. In terms of the input to the full + Transformer network, this typically requires padding sequence length to be multiple of 16. + + .. note:: + + When :attr:`recipe.reduce_amax==True`, any module must not be invoked more than once + inside a single `autocast` region. This is unsupported behavior because the amax + reduction is handled during the exit of the `autocast` context. Calling the same + module more than once inside an `autocast` region overrides the amax tensors + before reduction can occur. + + Parameters + ---------- + enabled: bool, default = `True` + whether or not to enable low precision quantization (FP8/FP4). + calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of quantized tensors even when executing without quantization enabled. + This is useful for saving an inference ready checkpoint while training + using a higher precision. + recipe: recipe.Recipe, default = `None` + recipe used for low precision quantization. + amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the quantized tensors + are reduced at the end of each training step. + """ + + if enabled: + check_recipe_support(recipe) + + # Save current state so we always restore it on exit. + fp8_state = FP8GlobalStateManager.get_autocast_state() + + FP8GlobalStateManager.autocast_enter( + enabled=enabled, + calibrating=calibrating, + fp8_recipe=recipe, + fp8_group=amax_reduction_group, + _graph=_graph, + ) + try: + yield + finally: + FP8GlobalStateManager.set_autocast_state(fp8_state) + FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + + +def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: + """Update amax history and set next amax to zero.""" + if amax_history.shape[0] > 1: + new_amax_history = torch.roll(amax_history, -1, 0) + amax_history.copy_(new_amax_history) + amax_history[0].fill_(0.0) + return amax_history + + +@torch.jit.script +def _default_get_amax_and_update_history( + amax_history: torch.Tensor, + amax_compute_algo: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Default function to obtain amax from history.""" + if amax_compute_algo == "max": + amax = torch.max(amax_history, dim=0).values + else: # amax_compute_algo == "most_recent" + amax = amax_history[0].clone() + + amax_history = _update_amax_history(amax_history) + return amax_history, amax + + +@jit_fuser +def _default_sf_compute( + amax: torch.Tensor, + scale: torch.Tensor, + fp8_max: float, + margin: int, + _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter +) -> torch.Tensor: + """Default function to convert amax to scaling factor. + Computing the scaling factor requires consideration of the following scenarios: + 1. amax == 0: + No action is possible, set scale to the previous scale (or 1). + 2. 0 < amax < tiny_amax + The amax is too tiny that the scale becomes infinite in FP32. + Set scale = FP32_max + 3. tiny_amax <= amax < FP32_max: + Set scale = FP8_max (or scaled_max) / amax + 4. When amax == inf or amax == nan: + No action is possible, set scale to the previous scale (or 1). + """ + sf = (fp8_max / amax) / (2**margin) + sf = torch.where(amax > 0.0, sf, scale) + sf = torch.where(torch.isfinite(amax), sf, scale) + sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) + scale.copy_(sf) + return scale + + +def _compute_amax_and_update_history( + amax_history: torch.Tensor, + amax_compute_algo: Union[Callable, str], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Obtain the amax from the history.""" + + if callable(amax_compute_algo): + amax = amax_compute_algo(amax_history) + amax_history = _update_amax_history(amax_history) + return amax_history, amax + return _default_get_amax_and_update_history( + amax_history, + amax_compute_algo, + ) + + +def _compute_scaling_factor( + amax: torch.Tensor, + scale: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, +) -> torch.Tensor: + """Convert amax to scaling factor.""" + + if recipe.scaling_factor_compute_algo is None: + return _default_sf_compute( + amax, + scale, + fp8_max, + recipe.margin, + ) + return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) + + +def _amax_and_scale_update( + amax_history: torch.Tensor, + scale: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, +) -> None: + """Updates FP8 meta tensors.""" + new_amax_history, amax = _compute_amax_and_update_history( + amax_history, + recipe.amax_compute_algo, + ) + new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) + scale.copy_(new_scale) + amax_history.copy_(new_amax_history) + + +def split_and_copy( + buffer: torch.Tensor, + outputs: List[torch.Tensor], + chunk_sizes: List[int], +) -> None: + """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" + splits = buffer.split(chunk_sizes) + torch._foreach_copy_(outputs, splits) + + +class RecipeState(abc.ABC): + """Configuration and state for a quantization recipe. + + This is a builder class for quantizers, which are in turn builder + classes for quantized tensors. + + This class may pack together the state for multiple quantizers, + which is helpful for applying fused kernels with less overhead. + + """ + + @staticmethod + def create( + recipe: Recipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> RecipeState: + """Factory method to create the state for a quantization recipe + + Parameters + ---------- + recipe: Recipe + Quantization recipe. + mode: {"forward", "backward"} + Training stage where quantization will be performed. + num_quantizers: int, default = 1 + Number of quantizers to create state for. + device: torch.device, default = default CUDA device + Device for quantized tensors. + + Returns + ------- + RecipeState: + Quantization recipe state. + + """ + + cls = None + if recipe.delayed(): + cls = DelayedScalingRecipeState + elif recipe.mxfp8(): + cls = MXFP8BlockScalingRecipeState + elif recipe.float8_current_scaling(): + cls = Float8CurrentScalingRecipeState + elif recipe.float8_block_scaling(): + cls = Float8BlockScalingRecipeState + elif recipe.nvfp4(): + cls = NVFP4BlockScalingRecipeState + elif recipe.custom(): + cls = CustomRecipeState + else: + raise ValueError(f"{recipe.__class__.__name__} is not supported") + return cls( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + + @abc.abstractmethod + def make_quantizers(self) -> list: + """Convert recipe state to quantizers. + + Quantizers are builder classes for quantized tensors. They are + typically used to convert a high-precision tensor (e.g. in + FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + +class DelayedScalingRecipeState(RecipeState): + """State for FP8 quantization with per-tensor delayed scaling. + + Delayed scaling recipe requires a scaling factor (applied when + casting to FP8) and a history of max-abs values ("amax") from + recent FP8 casts for updating the scaling factor. The scale update + is handled externally by `FP8GlobalStateManager`. + + """ + + recipe: DelayedScaling + mode: str + dtype: tex.DType + scale: torch.Tensor + amax_history: torch.Tensor + + def __init__( + self, + recipe: DelayedScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) + self.amax_history = torch.zeros( + recipe.amax_history_len, + num_quantizers, + dtype=torch.float32, + device=device, + ) + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_tensor import Float8Quantizer + + return [ + Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) + for i in range(self.num_quantizers) + ] + + +class Float8CurrentScalingRecipeState(RecipeState): + """Configuration for Per-tensor current scaling quantization. + + Per-tensor current quantization does not require state. + + """ + + recipe: Float8CurrentScaling + mode: str + dtype: tex.DType + device: torch.device + + def __init__( + self, + recipe: Float8CurrentScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + from .tensor.float8_tensor import Float8CurrentScalingQuantizer + + return [ + Float8CurrentScalingQuantizer( + self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales + ) + for i in range(self.num_quantizers) + ] + + +class MXFP8BlockScalingRecipeState(RecipeState): + """Configuration for MXFP8 quantization. + + MXFP8 quantization does not require state. + + """ + + recipe: MXFP8BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: MXFP8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.mxfp8_tensor import MXFP8Quantizer + + return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. + """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. + assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) + + assert self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) + + +class NVFP4BlockScalingRecipeState(RecipeState): + """Configuration for NVFP4 quantization. + + NVFP4 quantization does not require state. + + """ + + recipe: NVFP4BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + from .tensor.nvfp4_tensor import NVFP4Quantizer + + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward. It assumes forward quantizers are + # ordered [input, weight, output, ...] and backward quantizers + # are ordered [grad_output, grad_input, ...]. This doesn't + # play nicely with fusible ops: Linear op doesn't own output + # or grad input quantizers, Quantize op only owns input and + # grad output quantizers. + + if self.mode == "forward": + + def _make_quantizer(idx: int) -> NVFP4Quantizer: + qparams = ( + self.recipe.fp4_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp4_quant_fwd_inp + ) + return NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, + stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + ) + for _ in range(self.num_quantizers) + ] + + raise RuntimeError(f"Unexpected recipe mode ({self.mode})") + + +class CustomRecipeState(RecipeState): + """State for CustomRecipe: produce quantizers per tensor.""" + + recipe: CustomRecipe + mode: str + num_quantizers: int + device: Optional[torch.device] + + def __init__( + self, + recipe: CustomRecipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + if device is None: + device = torch.device("cuda") + self.device = device + + if getattr(recipe, "qfactory", None) is None: + raise ValueError("CustomRecipe requires `qfactory`.") + + def make_quantizers(self) -> list: + qfactory = self.recipe.qfactory + out = [] + + # TODO(negvet): make_quantizers() should take roles from the operation + # Hardcode linear-specific roles for now + roles: List[str] + if self.mode == "forward": + roles = [ + ("linear_input", "linear_weight", "linear_output")[i % 3] + for i in range(self.num_quantizers) + ] + elif self.mode == "backward": + roles = [ + ("linear_grad_output", "linear_grad_input")[i % 2] + for i in range(self.num_quantizers) + ] + else: + roles = ["unknown"] * self.num_quantizers + + for i in range(self.num_quantizers): + # Get quantizer from the user defined factory + quantizer = qfactory(roles[i]) + out.append(quantizer) + return out diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index b1a7e3731..2be0aed4a 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -16,6 +16,9 @@ from ..debug.pytorch.debug_quantization import DebugQuantizedTensor +__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] + + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" for tensor in tensors: @@ -453,13 +456,36 @@ def assert_dim_for_all_gather( ) -def is_bf16_compatible() -> None: +def is_bf16_compatible() -> bool: """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability to enforce sm_80 or higher. """ return torch.cuda.get_device_capability()[0] >= 8 +def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine whether bfloat16 (BF16) computation is supported on the current device. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating BF16 availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when BF16 is not available. When BF16 is available, + the reason will be an empty string. + + """ + available = is_bf16_compatible() + if not return_reason: + return available + + reason = ( + "" if available else "BF16 support requires a GPU with compute capability 8.0 or higher." + ) + return available, reason + + @functools.lru_cache(maxsize=None) def is_non_tn_fp8_gemm_supported() -> bool: """Checks whether the device supports From fd2f589f26dfedeb4b3705f0137ec89fe8e5d8cf Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 14 Oct 2025 14:57:34 -0700 Subject: [PATCH 056/141] [PyTorch] Bump minimum cuDNN version for fused attention with FP8 current scaling (#2236) * Require cuDNN 9.14.0+ for fused attention with FP8 current scaling Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Kirthi Shankar Sivamani --- .../attention/dot_product_attention/utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c8cc3d29f..b45edc716 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -469,13 +469,13 @@ def get_attention_backend( fp8_recipe = fp8_meta["recipe"] if fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - if ( - use_fused_attention - and fp8_recipe.float8_current_scaling() - and device_compute_capability < (10, 0) - ): - logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") - use_fused_attention = False + if use_fused_attention and fp8_recipe.float8_current_scaling(): + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") + use_fused_attention = False + elif cudnn_version < (9, 14, 0): + logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0") + use_fused_attention = False # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size From 4c572f040d3b8e9a29f512c9537b4f4a526f843e Mon Sep 17 00:00:00 2001 From: Paul Gibbons <87940629+paul-gibbons@users.noreply.github.com> Date: Wed, 15 Oct 2025 03:33:53 -0400 Subject: [PATCH 057/141] [PyTorch Debug] Fix issue with start_end_list logging feature (#2252) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fixes for start_end_list usage in TE debug Signed-off-by: Paul Gibbons * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Paul Gibbons Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> --- .../debug/features/log_fp8_tensor_stats.py | 12 +++++++++--- .../debug/features/log_tensor_stats.py | 12 +++++++++--- .../debug/features/utils/stats_buffer.py | 10 +++++++++- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index 31620211d..d09fb1057 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -290,10 +290,16 @@ def inspect_tensor( for stat in config["stats"]: self.check_if_stat_is_supported(stat, recipe_name) + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + options = ( - config.get("start_step", None), - config.get("end_step", None), - config.get("start_end_list", None), + start_step, + end_step, + start_end_list, "fp8", ) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 5d721d996..e917cf9a0 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -130,10 +130,16 @@ def inspect_tensor( " log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors." ) + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + options = ( - config.get("start_step", None), - config.get("end_step", None), - config.get("start_end_list", None), + start_step, + end_step, + start_end_list, ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index f07602d23..20236fb95 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -172,11 +172,19 @@ def _if_run_reduction(self) -> bool: if self.at_least_one_layer_fed: return True iteration = TEDebugState.get_iteration() - for _, next_iter in self.layers_to_next_iter.items(): + layers_to_remove = [] + for layer_name, next_iter in self.layers_to_next_iter.items(): + # When next_iter is None the feature will no longer run. + if next_iter is None: + layers_to_remove.append(layer_name) + continue # Note that layer can be not run for many iterations, # in this case we will synchronize until every step until we get any information from it. if iteration >= next_iter: return True + + for layer_name in layers_to_remove: + self.layers_to_next_iter.pop(layer_name, None) return False def reset(self): From 88564d594644527255adfc4df52a391996def736 Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Wed, 15 Oct 2025 17:08:29 -0500 Subject: [PATCH 058/141] README - latest news update (#2273) * Enhance Latest News section with recent TE and FP8 developments - Adds NVFP4 pretraining research paper with PR #2177 reference Signed-off-by: Santosh Bhavani * update nvfp4 reference Signed-off-by: Santosh Bhavani * Update README.rst Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Santosh Bhavani Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- README.rst | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/README.rst b/README.rst index 380a99edf..9b65c60ae 100644 --- a/README.rst +++ b/README.rst @@ -12,6 +12,14 @@ Transformer Engine Latest News =========== + +* [09/2025] `Pretraining Large Language Models with NVFP4 `_ +* [09/2025] `Native FP8 Mixed Precision Training for Ling 2.0, Open Sourced! `_ +* [09/2025] `Faster Training Throughput in FP8 Precision with NVIDIA NeMo `_ +* [08/2025] `How we built DeepL's next-generation LLMs with FP8 for training and inference `_ +* [08/2025] `NVFP4 Trains with Precision of 16-bit and Speed and Efficiency of 4-bit `_ +* [06/2025] `Floating Point 8: An Introduction to Efficient, Lower-Precision AI Training `_ +* [05/2025] `Advanced Optimization Strategies for LLM Training on NVIDIA Grace Hopper `_ * [03/2025] `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 `_ * [03/2025] `Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking `_ From 452c73746285407ec99566a2b9df3cbdfddfc5a7 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Thu, 16 Oct 2025 09:51:27 -0600 Subject: [PATCH 059/141] Added support for DistOpt with offloading with MoE's (#2264) Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/module/grouped_linear.py | 27 +++++++++++++++---- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 53a32818b..a5bf21ee1 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -209,6 +209,19 @@ def forward( if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_objects = [] + for weight in weights: + ctx.weight_objects.append(weight) + tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, *weights_fp8, @@ -271,11 +284,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - for i in range(ctx.num_gemms): - w = torch.nn.Parameter(weights[i], weights[i].requires_grad) - w.main_grad = main_grads[i] - weights[i] = w + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + for i, weight in enumerate(ctx.weight_objects): + origin_weights[i] = ctx.weight_objects[i] + ctx.weight_objects[i] = None + + if ctx.fuse_wgrad_accumulation: + for i in range(N): + origin_weights[i].main_grad = main_grads[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) From 81c363bf0939d027f03cad2d7fdea2b00ea66c16 Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Fri, 17 Oct 2025 06:45:59 +0800 Subject: [PATCH 060/141] [PyTorch] Add record_stream and untyped_storage func op in QuantizedTensor (#2144) * [PyTorch] Add record_stream and untyped_storage func op in QuantizedTensor Signed-off-by: xiaoxi-wangfj <690912414@qq.com> * Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: xiaoxi-wangfj <690912414@qq.com> * Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --------- Signed-off-by: xiaoxi-wangfj <690912414@qq.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../pytorch/tensor/float8_blockwise_tensor.py | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 16631a3d0..48762499b 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -403,6 +403,21 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) + def untyped_storage(self) -> torch.UntypedStorage: + """Return the underlying UntypedStorage of the FP8 data. + + Note that FP8 block-scaled tensor may involve multiple + buffers: row-wise FP8 data, row-wise scales, column-wise FP8 + data, column-wise scales. The UntypedStorage of the row-wise + FP8 data is returned if it exists, and otherwise the + UntypedStorage of the column-wise FP8 data. + + """ + data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data + if data is not None: + return data.untyped_storage() + return torch.UntypedStorage(0, device=self.device) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -427,6 +442,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8BlockwiseQTensor.make_like(tensor) + # record stream op + if func == torch.ops.aten.record_stream.default: + qt, stream = args + for t in ( + qt._rowwise_data, + qt._columnwise_data, + qt._rowwise_scale_inv, + qt._columnwise_scale_inv, + ): + if t is not None and t.is_cuda: + t.record_stream(stream) + return None + # Default case return super().__torch_dispatch__(func, types, args, kwargs) From 5624dbb4796124e2b23d6a295783a3158944758d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 16 Oct 2025 16:36:24 -0700 Subject: [PATCH 061/141] Changed VERSION to 2.10.0.dev0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 8bfb1cae8..c7f2fd9b8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.9.0.dev0 +2.10.0.dev0 From 9dd619222283da058f28c6c81451f2c361434c98 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 16 Oct 2025 20:45:47 -0700 Subject: [PATCH 062/141] [JAX] Fix imports in test for deprecated jax.experimental.pjit (#2274) * Fix imports in test for deprecated jax.experimental.pjit Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix: Pass NamedSharding instead of PartitionSpec to compare_ops() so that when the in and out sharding is used to create a jitted function, it has the mesh info Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Signed-off-by: Kshitij Janardan Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kshitij Janardan Lakhani --- tests/jax/distributed_test_base.py | 14 +++++++------ tests/jax/test_distributed_layernorm.py | 26 ++++++++++++++++--------- tests/jax/test_distributed_softmax.py | 10 ++++++---- 3 files changed, 31 insertions(+), 19 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 4693086b8..137fa480d 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -8,7 +8,7 @@ import pytest import jax -from jax.experimental.pjit import pjit, _UNSPECIFIED +from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource @@ -154,13 +154,15 @@ def compare_ops( grad_args = tuple(range(len(inputs))) target_grad_func = jax.value_and_grad(target_func, argnums=grad_args) - target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) - target_fwd, target_grads = target_pjitter(*inputs, **kwargs) - target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text() + target_jitter = jax.jit( + target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings + ) + target_fwd, target_grads = target_jitter(*inputs, **kwargs) + target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text() ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args) - ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) - ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs) + ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) + ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs) assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 977d010af..d551b7390 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -134,9 +134,12 @@ def ref_func(x, gamma, beta): devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): - x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) - gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) - beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) + x_named_sharding = NamedSharding(mesh, x_pspec) + g_named_sharding = NamedSharding(mesh, g_pspec) + b_named_sharding = NamedSharding(mesh, b_pspec) + x_ = jax.device_put(x, x_named_sharding) + gamma_ = jax.device_put(gamma, g_named_sharding) + beta_ = jax.device_put(beta, b_named_sharding) with warnings.catch_warnings(record=True) as warns: try: @@ -148,8 +151,11 @@ def ref_func(x, gamma, beta): grad_args=(0, 1, 2), metric_fwd_dtype=q_dtype, metric_bwd_dtype=q_dtype, - in_shardings=(x_pspec, g_pspec, b_pspec), - out_shardings=(None, (x_pspec, g_pspec, b_pspec)), + in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding), + out_shardings=( + None, + (x_named_sharding, g_named_sharding, b_named_sharding), + ), ) except AssertionError as err: # Layernorm should still produce the correct numerical result with @@ -210,8 +216,10 @@ def ref_func(x, gamma): devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): - x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) - gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) + x_named_sharding = NamedSharding(mesh, x_pspec) + g_named_sharding = NamedSharding(mesh, g_pspec) + x_ = jax.device_put(x, x_named_sharding) + gamma_ = jax.device_put(gamma, g_named_sharding) with warnings.catch_warnings(record=True) as warns: try: @@ -223,8 +231,8 @@ def ref_func(x, gamma): grad_args=(0, 1), metric_fwd_dtype=q_dtype, metric_bwd_dtype=q_dtype, - in_shardings=(x_pspec, g_pspec), - out_shardings=(None, (x_pspec, g_pspec)), + in_shardings=(x_named_sharding, g_named_sharding), + out_shardings=(None, (x_named_sharding, g_named_sharding)), ) except AssertionError as err: # RmsNorm should still produce the correct numerical result with diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 2bd4d862a..f1ae6c9e4 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -103,8 +103,10 @@ def impl_test_softmax( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, autocast(mesh_resource=mesh_resource): - x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) - mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) + x_named_sharding = NamedSharding(mesh, x_pspec) + mask_named_sharding = NamedSharding(mesh, mask_pspec) + x_ = jax.device_put(x, x_named_sharding) + mask_ = jax.device_put(mask, mask_named_sharding) with warnings.catch_warnings(record=True) as warns: try: @@ -116,8 +118,8 @@ def impl_test_softmax( grad_args=(0,), metric_fwd_dtype=dtype, metric_bwd_dtype=dtype, - in_shardings=(x_pspec, mask_pspec), - out_shardings=(None, (x_pspec,)), + in_shardings=(x_named_sharding, mask_named_sharding), + out_shardings=(None, x_named_sharding), ) except AssertionError as err: # Softmax should still produce the correct numerical result with From 05dc1e624386b38bafbd2c8bba55e3ea8eb16a49 Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Fri, 17 Oct 2025 06:48:31 -0700 Subject: [PATCH 063/141] NVFP4 Move RHT BLAS to GPU (#2275) * CUDA RHT Signed-off-by: Kevin Tong * Fix cuda graphs Signed-off-by: Kirthi Shankar Sivamani * Fix bug where RHT mask is tensor instead of int Signed-off-by: Tim Moon --------- Signed-off-by: Kevin Tong Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: Tim Moon Co-authored-by: Kirthi Shankar Sivamani Co-authored-by: Tim Moon --- transformer_engine/pytorch/tensor/nvfp4_tensor.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index ca2154f55..5e2eeed72 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -29,7 +29,7 @@ def get_no_random_sign_vector() -> torch.Tensor: """Non-random sign vector for Hadamard transform.""" - return torch.tensor([1], dtype=torch.float32) + return torch.tensor([1], dtype=torch.float32, device="cuda") def get_sign_from_vector(vector: torch.Tensor) -> int: @@ -41,7 +41,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int: mask = 0 for i, v in enumerate(vector): mask |= (v == -1) << i - return mask + return mask.item() def get_wgrad_sign_vector() -> torch.Tensor: @@ -53,6 +53,7 @@ def get_wgrad_sign_vector() -> torch.Tensor: return torch.tensor( [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], dtype=torch.float32, + device="cuda", ) @@ -81,6 +82,7 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor: [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1], ], dtype=torch.float32, + device="cuda", ) * hadamard_scale ) @@ -94,9 +96,9 @@ def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor: signs = get_wgrad_sign_vector() else: signs = get_no_random_sign_vector() - sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32) + sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda") rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension) - return rht_matrix.to(dtype=torch.bfloat16).cuda() + return rht_matrix.to(dtype=torch.bfloat16) @functools.lru_cache(maxsize=None) From bd38004800052f88e9773d1db5c84b596f0f5861 Mon Sep 17 00:00:00 2001 From: Tim Geypens Date: Fri, 17 Oct 2025 19:21:59 +0200 Subject: [PATCH 064/141] fall back after failing ldconfig-based lib loading for cuDNN (#2277) Signed-off-by: Tim Geypens --- transformer_engine/common/__init__.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index dd1ec480b..134705f60 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -252,9 +252,7 @@ def _load_cudnn(): return handle # Attempt to locate libcudnn via ldconfig - libs = subprocess.check_output( - f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True - ) + libs = subprocess.check_output(["ldconfig", "-p"]) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: @@ -284,9 +282,7 @@ def _load_nvrtc(): return handle # Attempt to locate NVRTC via ldconfig - libs = subprocess.check_output( - f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True - ) + libs = subprocess.check_output(["ldconfig", "-p"]) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: @@ -316,9 +312,7 @@ def _load_curand(): return handle # Attempt to locate cuRAND via ldconfig - libs = subprocess.check_output( - f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True - ) + libs = subprocess.check_output(["ldconfig", "-p"]) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: From a7a69ca61c050df7bb78dc5a7f0a0077f6f57946 Mon Sep 17 00:00:00 2001 From: Haowen Zheng <157908761+Owen1B@users.noreply.github.com> Date: Sat, 18 Oct 2025 01:26:41 +0800 Subject: [PATCH 065/141] Bump up FA to 2.8.3 (#2282) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 将来 Co-authored-by: 将来 Co-authored-by: Kirthi Shankar Sivamani --- qa/L3_pytorch_FA_versions_test/test.sh | 4 ++-- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 7e9616cd0..418e824c1 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri export FLASH_ATTN_CUDA_ARCHS=$sm_arch if [ $sm_arch -gt 90 ] then - FA_versions=(2.8.1) + FA_versions=(2.8.3) elif [ $sm_arch -eq 90 ] then - FA_versions=(2.7.3 2.8.1 3.0.0b1) + FA_versions=(2.7.3 2.8.3 3.0.0b1) fi for fa_version in "${FA_versions[@]}" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index b45edc716..174d7ee9e 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -115,7 +115,7 @@ class FlashAttentionUtils: version = PkgVersion("0") version_required = PkgVersion("2.1.1") version_required_blackwell = PkgVersion("2.7.3") - max_version = PkgVersion("2.8.1") + max_version = PkgVersion("2.8.3") v2_plus = False v2_1_plus = False v2_3_plus = False From c593bcefc1379a5b3d795bf913315cd209c59835 Mon Sep 17 00:00:00 2001 From: Neil Tenenholtz Date: Fri, 17 Oct 2025 13:31:21 -0400 Subject: [PATCH 066/141] Fix test of FSDP2 by correcting init logic and applying autocast (#2105) * Fix test of FSDP2 by correcting init logic and applying autocast This fixes multiple issues in the FSDP2 test, namely 1. Previously fp8 init was performed when `args.fp8_init == False`. I have updated the logic to match what I presume was intended by leveraging the nullcontext context manager. 2. `te.fp8_autocast` was previously not called; the recipe was created but was unused. The autocast context manager now wraps the model's computation. Signed-off-by: Neil Tenenholtz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix typo Signed-off-by: Neil Tenenholtz * Update tests/pytorch/distributed/run_fsdp2_model.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug when constructing context for model init Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Neil Tenenholtz Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/distributed/run_fsdp2_model.py | 23 +++++++++----------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 8026fc0a3..d3f8c82ba 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -105,23 +105,19 @@ def _train(args): fp8_format = Format.HYBRID fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - if not args.fp8_init: - # Build model context (FP8 init) - build_model_context = nullcontext - build_model_context_args = {} - + # Create build context manager + if args.fp8_init: from transformer_engine.pytorch import quantized_model_init - build_model_context = quantized_model_init - build_model_context_args["enabled"] = True - - # Build the model with the specified context - with build_model_context(**build_model_context_args): - model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + build_model_context = quantized_model_init() else: + build_model_context = nullcontext() + + # Build the model with the specified context + with build_model_context: model = SimpleNet(args.input_size, args.hidden_size, args.output_size) - # Move the model to the correct device + # Move the model to the correct device model.to(device) if LOCAL_RANK == 0: @@ -163,7 +159,8 @@ def _train(args): # Zero the parameter gradients optimizer.zero_grad() input_data = torch.randn(args.batch_size, args.input_size).to(device) - output = model(input_data) + with te.autocast(enabled=True, recipe=fp8_recipe): + output = model(input_data) target = torch.randn(args.batch_size, args.output_size).to(device) loss = F.mse_loss(output, target) loss.backward() From ee384ab566709144e91c2949625bb1f1357bee50 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Fri, 17 Oct 2025 13:02:08 -0500 Subject: [PATCH 067/141] Make `CanonicalizeGemmInput()` support non-TN layout FP8 GEMM on Blackwell with column-wise/transposed data (#2233) Modified CanonicalizeGemmInput() logic to pull from column-wise data for FP8 GEMM on Blackwell when row-wise is not available. Signed-off-by: Alp Dener --- .../common/gemm/cublaslt_gemm.cu | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 84a1b735a..97e8ec9a3 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -140,6 +140,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } + } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed + // data with the mirrored transpose-flag if we don't have row-wise data. + NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), + "Input A is missing column-wise usage"); + ret.A = A.columnwise_data.dptr; + ret.transA = is_A_transposed ? CUBLAS_OP_N : CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? m : k; } if (is_fp8_dtype(ret.Atype)) { @@ -221,6 +231,16 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } + } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed + // data with the mirrored transpose-flag if we don't have row-wise data. + NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), + "Input B is missing column-wise usage"); + ret.B = B.columnwise_data.dptr; + ret.transB = is_B_transposed ? CUBLAS_OP_N : CUBLAS_OP_T; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = is_B_transposed ? k : n; } if (is_fp8_dtype(ret.Atype)) { From fd234d8006c9b9fc293f569be36d9278563e99db Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sat, 18 Oct 2025 00:00:01 -0400 Subject: [PATCH 068/141] Wheels for cuda 13 (#2278) * Support wheel build for cuda 13 Signed-off-by: Kirthi Shankar Sivamani * Fixes Signed-off-by: Kirthi Shankar Sivamani * Fixes for cu13 runtime, format Signed-off-by: Kirthi Shankar Sivamani * Add documentation Signed-off-by: Kirthi Shankar Sivamani * Better error handling Signed-off-by: Kirthi Shankar Sivamani * fix Signed-off-by: Kirthi Shankar Sivamani * fix jax sdist Signed-off-by: Kirthi Shankar Sivamani * Modify function names Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- README.rst | 2 +- build_tools/wheel_utils/Dockerfile.aarch | 29 ++++-- build_tools/wheel_utils/Dockerfile.x86 | 29 ++++-- build_tools/wheel_utils/build_wheels.sh | 18 ++-- build_tools/wheel_utils/launch_aarch.sh | 28 ++++- build_tools/wheel_utils/launch_x86.sh | 28 ++++- docs/installation.rst | 8 ++ setup.py | 5 +- transformer_engine/common/__init__.py | 124 ++++++++++++++++------- transformer_engine/jax/setup.py | 32 +++++- transformer_engine/pytorch/setup.py | 14 ++- 11 files changed, 243 insertions(+), 74 deletions(-) diff --git a/README.rst b/README.rst index 9b65c60ae..50c1dcd80 100644 --- a/README.rst +++ b/README.rst @@ -205,7 +205,7 @@ pip Installation **Prerequisites for pip installation:** * A compatible C++ compiler -* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed +* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) if installing from source. To install the latest stable version with pip: diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index 223c4a7f1..404cb941c 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_aarch64 WORKDIR /TransformerEngine/ COPY ../.. /TransformerEngine/ -ARG VER="12-3" -ARG ARCH="aarch64" -RUN dnf -y install vim +ARG CUDA_MAJOR="12" +ARG CUDA_MINOR="3" + +# Args for build_wheels.sh +ARG BUILD_METAPACKAGE=true +ARG BUILD_COMMON=true +ARG BUILD_PYTORCH=true +ARG BUILD_JAX=true +ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE} +ENV BUILD_COMMON=${BUILD_COMMON} +ENV BUILD_PYTORCH=${BUILD_PYTORCH} +ENV BUILD_JAX=${BUILD_JAX} +ENV CUDA_MAJOR=${CUDA_MAJOR} # Cuda toolkit, cudnn, driver. RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo RUN dnf -y install epel-release -RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ - cuda-libraries-${VER}.${ARCH} \ - cuda-libraries-devel-${VER}.${ARCH} -RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \ + cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \ + cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 +RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR} RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit +RUN dnf -y install cuda-toolkit-${CUDA_MAJOR} RUN dnf clean all RUN dnf -y install glog.aarch64 glog-devel.aarch64 +RUN dnf -y install libnccl libnccl-devel libnccl-static ENV PATH="/usr/local/cuda/bin:${PATH}" ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" @@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] +CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_aarch64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 26122eed9..daa7f961c 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 WORKDIR /TransformerEngine/ COPY ../.. /TransformerEngine/ -ARG VER="12-3" -ARG ARCH="x86_64" -RUN dnf -y install vim +ARG CUDA_MAJOR="12" +ARG CUDA_MINOR="3" + +# Args for build_wheels.sh +ARG BUILD_METAPACKAGE=true +ARG BUILD_COMMON=true +ARG BUILD_PYTORCH=true +ARG BUILD_JAX=true +ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE} +ENV BUILD_COMMON=${BUILD_COMMON} +ENV BUILD_PYTORCH=${BUILD_PYTORCH} +ENV BUILD_JAX=${BUILD_JAX} +ENV CUDA_MAJOR=${CUDA_MAJOR} # Cuda toolkit, cudnn, driver. RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo RUN dnf -y install epel-release -RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ - cuda-libraries-${VER}.${ARCH} \ - cuda-libraries-devel-${VER}.${ARCH} -RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \ + cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \ + cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 +RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR} RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit +RUN dnf -y install cuda-toolkit-${CUDA_MAJOR} RUN dnf clean all RUN dnf -y install glog.x86_64 glog-devel.x86_64 +RUN dnf -y install libnccl libnccl-devel libnccl-static ENV PATH="/usr/local/cuda/bin:${PATH}" ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" @@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] +CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_x86_64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"] \ No newline at end of file diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index bf4f9d2bc..954a8f1c6 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -9,8 +9,10 @@ BUILD_METAPACKAGE=${2:-true} BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} BUILD_JAX=${5:-true} +CUDA_MAJOR=${6:-12} export NVTE_RELEASE_BUILD=1 +export PIP_CONSTRAINT="" export TARGET_BRANCH=${TARGET_BRANCH:-} mkdir -p /wheelhouse/logs @@ -21,7 +23,7 @@ git checkout $TARGET_BRANCH git submodule update --init --recursive # Install deps -/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja +/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1 if $BUILD_METAPACKAGE ; then cd /TransformerEngine @@ -36,32 +38,32 @@ if $BUILD_COMMON ; then # Create the wheel. /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt - # Repack the wheel for cuda specific package, i.e. cu12. + # Repack the wheel for specific cuda version. /opt/python/cp310-cp310/bin/wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info" /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" - whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}" + whl_name_target="${whl_parts[0]}_cu${CUDA_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}" rm -rf $WHL_BASE dist mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then cd /TransformerEngine/transformer_engine/pytorch - /opt/python/cp310-cp310/bin/pip install torch + /opt/python/cp310-cp310/bin/pip install torch /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt cp dist/* /wheelhouse/ fi if $BUILD_JAX ; then cd /TransformerEngine/transformer_engine/jax - /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib + /opt/python/cp310-cp310/bin/pip install "jax[cuda${CUDA_MAJOR}_local]" jaxlib /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt cp dist/* /wheelhouse/ fi diff --git a/build_tools/wheel_utils/launch_aarch.sh b/build_tools/wheel_utils/launch_aarch.sh index 04e3cd691..85f754ca1 100644 --- a/build_tools/wheel_utils/launch_aarch.sh +++ b/build_tools/wheel_utils/launch_aarch.sh @@ -2,7 +2,29 @@ # # See LICENSE for license information. -docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . +# Remove leftovers. +rm -rf aarch_wheelhouse_cu12 aarch_wheelhouse_cu13 + +# CUDA 12. +docker build --no-cache \ + --build-arg CUDA_MAJOR=12 \ + --build-arg CUDA_MINOR=3 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . +docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" +docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu12 + +# CUDA 13. +docker build --no-cache \ + --build-arg CUDA_MAJOR=13 \ + --build-arg CUDA_MINOR=0 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" -rm -rf aarch_wheelhouse -docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse +docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu13 diff --git a/build_tools/wheel_utils/launch_x86.sh b/build_tools/wheel_utils/launch_x86.sh index b0d20be3f..11fc52294 100644 --- a/build_tools/wheel_utils/launch_x86.sh +++ b/build_tools/wheel_utils/launch_x86.sh @@ -2,7 +2,29 @@ # # See LICENSE for license information. -docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . +# Remove leftovers. +rm -rf x86_wheelhouse_cu12 x86_wheelhouse_cu13 + +# CUDA 12. +docker build --no-cache \ + --build-arg CUDA_MAJOR=12 \ + --build-arg CUDA_MINOR=3 \ + --build-arg BUILD_METAPACKAGE=true \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=true \ + --build-arg BUILD_JAX=true \ + -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . +docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" +docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu12 + +# CUDA 13. +docker build --no-cache \ + --build-arg CUDA_MAJOR=13 \ + --build-arg CUDA_MINOR=0 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" -rm -rf x86_wheelhouse -docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse +docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu13 diff --git a/docs/installation.rst b/docs/installation.rst index ecb1e9a0d..a8bb74fd1 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -38,6 +38,14 @@ Transformer Engine can be directly installed from `our PyPI Tuple[List[str], List[str]]: ext_modules = [] package_data = {} include_package_data = False - install_requires = ([f"transformer_engine_cu12=={__version__}"],) + install_requires = [] extras_require = { + "core": [f"transformer_engine_cu12=={__version__}"], + "core_cu12": [f"transformer_engine_cu12=={__version__}"], + "core_cu13": [f"transformer_engine_cu13=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 134705f60..3ffe1c7b1 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -8,22 +8,18 @@ import functools import glob import importlib -from importlib.metadata import version, metadata, PackageNotFoundError -import logging +from importlib.metadata import version, distribution, PackageNotFoundError import os from pathlib import Path import platform import subprocess import sys import sysconfig -from typing import Optional - - -_logger = logging.getLogger(__name__) +from typing import Optional, Tuple @functools.lru_cache(maxsize=None) -def _is_pip_package_installed(package) -> bool: +def _is_package_installed(package) -> bool: """Check if the given package is installed via pip.""" # This is needed because we only want to return true @@ -31,12 +27,34 @@ def _is_pip_package_installed(package) -> bool: # if it's importable in the current directory due to # the presence of the shared library module. try: - metadata(package) + distribution(package) except PackageNotFoundError: return False return True +@functools.lru_cache(maxsize=None) +def _is_package_installed_from_wheel(package) -> bool: + """Check if the given package is installed via PyPI.""" + + if not _is_package_installed(package): + return False + + te_dist = distribution(package) + te_wheel_file = "" + for file_path in te_dist.files: + if file_path.name == "WHEEL": + te_wheel_file = te_dist.locate_file("") / file_path + if not te_wheel_file: + return False + + with te_wheel_file.open("r") as f: + for line in f: + if line.startswith("Root-Is-Purelib:"): + return line.strip().split(":")[1].strip().lower() == "true" + return False + + @functools.lru_cache(maxsize=None) def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]: """ @@ -112,6 +130,19 @@ def _get_shared_object_file(library: str) -> Path: ) +def get_te_core_package_info() -> Tuple[bool, str, str]: + """ + Check if Tranformer Engine core package is installed. + Returns the module name and version if found. + """ + + te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") + for package in te_core_packages: + if _is_package_installed(package): + return True, package, version(package) + return False, "", "" + + @functools.lru_cache(maxsize=None) def load_framework_extension(framework: str) -> None: """ @@ -130,39 +161,30 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" + # Find the TE packages. The core and framework packages can only be installed via PyPI. + # For the `transformer-engine` package, we need to check explicity. + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() + te_framework_installed = _is_package_installed(module_name) + te_installed = _is_package_installed("transformer_engine") + te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") + + assert te_installed, "Could not find `transformer_engine`." + # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework - # extension are all installed via PyPI and have matching version. - if _is_pip_package_installed(module_name): - assert _is_pip_package_installed( - "transformer_engine" - ), "Could not find `transformer-engine`." - assert _is_pip_package_installed( - "transformer_engine_cu12" - ), "Could not find `transformer-engine-cu12`." - assert ( - version(module_name) - == version("transformer-engine") - == version("transformer-engine-cu12") - ), ( - "TransformerEngine package version mismatch. Found" + # extension are all installed via PyPI and have matching versions. + if te_framework_installed: + assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." + assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`." + + assert version(module_name) == version("transformer-engine") == te_core_version, ( + "Transformer Engine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-cu12" - f" v{version('transformer-engine-cu12')}. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" + f" v{version('transformer-engine')}, and {te_core_package_name}" + f" v{te_core_version}. Install transformer-engine using " + f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'" ) - # If the core package is installed via PyPI, log if - # the framework extension is not found from PyPI. - # Note: Should we error? This is a rare use case. - if _is_pip_package_installed("transformer-engine-cu12"): - if not _is_pip_package_installed(module_name): - _logger.info( - "Could not find package %s. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'", - module_name, - ) - # After all checks are completed, load the shared object file. spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) solib = importlib.util.module_from_spec(spec) @@ -170,6 +192,35 @@ def load_framework_extension(framework: str) -> None: spec.loader.exec_module(solib) +def sanity_checks_for_pypi_installation() -> None: + """Ensure that package is installed correctly if using PyPI.""" + + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() + te_installed = _is_package_installed("transformer_engine") + te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") + + assert te_installed, "Could not find `transformer-engine`." + + # If the core package is installed via PyPI. + if te_core_installed: + assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." + assert version("transformer-engine") == te_core_version, ( + "Transformer Engine package version mismatch. Found " + f"transformer-engine v{version('transformer-engine')} " + f"and {te_core_package_name} v{te_core_version}." + ) + + # Only the metapackage is found, invalid usecase. + elif te_installed_via_pypi: + raise RuntimeError( + "Found empty `transformer-engine` meta package installed. " + "Install `transformer-engine` with framework extensions via" + "'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'" + " or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`" + " or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib." + ) + + @functools.lru_cache(maxsize=None) def _get_sys_extension() -> str: """File extension for shared objects.""" @@ -332,6 +383,7 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): + sanity_checks_for_pypi_installation() _CUDNN_LIB_CTYPES = _load_cudnn() _NVRTC_LIB_CTYPES = _load_nvrtc() _CURAND_LIB_CTYPES = _load_curand() diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index f83375d82..ccdbcdb52 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -54,6 +54,26 @@ CMakeBuildExtension = get_build_ext(BuildExtension, True) +def get_cuda_major_version() -> int: + """Get CUDA major version using Jax backend.""" + + assert ( + jax._src.lib.cuda_versions is not None + ), "GPU backend is required to build TE jax extensions." + + # Jax currently does not have any stable/public method to get cuda version. + # Try using internal function and default to cuda12 if not found. + try: + cuda_version = jax._src.lib.cuda_versions.cuda_runtime_get_version() + cuda_major_version = cuda_version // 1000 + except AttributeError: + cuda_version = os.getenv("CUDA_VERSION", "12") + cuda_major_version = int(cuda_version.split(".")[0]) + + assert cuda_major_version in (12, 13), f"Unsupported cuda version {cuda_version}." + return cuda_major_version + + if __name__ == "__main__": """Main entry point for JAX extension installation. @@ -93,15 +113,23 @@ ) ] + # Setup version and requirements. + # Having the framework extension depend on the core lib allows + # us to detect CUDA version dynamically during compilation and + # choose the correct wheel for te core lib. + __version__ = te_version() + te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" + install_requires = install_requirements() + [te_core] + # Configure package setuptools.setup( name="transformer_engine_jax", - version=te_version(), + version=__version__, description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, python_requires=f">={min_python_version_str()}", - install_requires=install_requirements(), + install_requires=install_requires, tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 08870040f..7a8155004 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -145,15 +145,25 @@ def run(self): ) ] + # Setup version and requirements. + # Having the framework extension depend on the core lib allows + # us to detect CUDA version dynamically during compilation and + # choose the correct wheel for te core lib. + __version__ = te_version() + cuda_major_version = parse(torch.version.cuda).major + assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}." + te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" + install_requires = install_requirements() + [te_core] + # Configure package setuptools.setup( name=PACKAGE_NAME, - version=te_version(), + version=__version__, description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, python_requires=f">={min_python_version_str()}", - install_requires=install_requirements(), + install_requires=install_requires, tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): From dd7ab715a55d18740de5f10546ac71842f832e07 Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Mon, 20 Oct 2025 23:02:05 +0800 Subject: [PATCH 069/141] Fix error with triton 3.5 (#2286) * Update permutation.py Signed-off-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> * Update permutation.py Signed-off-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> * Update transformer_engine/pytorch/triton/permutation.py Signed-off-by: Kirthi Shankar Sivamani * Update transformer_engine/pytorch/triton/permutation.py Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/triton/permutation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 6292acb69..1474a664c 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -12,11 +12,16 @@ from triton.language import core from triton.language.standard import _log2 +from packaging import version # The following three argsort related kernels are adapted from # the issue https://github.com/triton-lang/triton/issues/3698 +get_int_dtype = core.get_int_dtype +if version.parse(triton.__version__) >= version.parse("3.5.0"): + get_int_dtype = triton.constexpr_function(get_int_dtype) + @triton.jit def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): @@ -37,7 +42,7 @@ def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape) r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) - idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) il_value = l_value.to(idtype, bitcast=True) ir_value = r_value.to(idtype, bitcast=True) From bd55e7ba5f0235a80eaa63d49adaa8fb7c6ced50 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 20 Oct 2025 16:28:23 -0400 Subject: [PATCH 070/141] [PyTorch] Fix CI failures due to deterministic attention backend (#2288) * Fix CI failures due to deterministic attention Signed-off-by: Kirthi Shankar Sivamani * some more cleanup Signed-off-by: Kirthi Shankar Sivamani * Fix debug test Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_debug_unittest/test.sh | 2 +- qa/L0_pytorch_unittest/test.sh | 4 +-- tests/pytorch/test_numerics.py | 30 +------------------ .../attention/dot_product_attention/utils.py | 2 +- 4 files changed, 5 insertions(+), 33 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index 7f19dda67..9980ccfb0 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -32,6 +32,6 @@ pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/ # standard sanity and numerics tests with initialized debug NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 exit $FAIL diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index cdf0df888..b23ce3b6c 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -27,8 +27,8 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index bef076a38..35698b819 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -43,11 +43,10 @@ ) from transformer_engine.pytorch import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.common import recipe import transformer_engine_torch as tex -from utils import ModelConfig, reset_rng_states, get_available_attention_backends +from utils import ModelConfig, reset_rng_states # Only run FP8 tests on supported devices. @@ -130,23 +129,6 @@ use_cutlass_grouped_gemm.append(True) -def is_fused_attn_available( - config: ModelConfig, - dtype: torch.dtype, - qkv_layout="bshd_bshd_bshd", - is_training=True, - deterministic=False, -): - _, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - is_training=is_training, - deterministic=deterministic, - ) - return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends - - def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -853,8 +835,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] - if not is_fused_attn_available(config, dtype, deterministic=True): - pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -901,10 +881,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] - if not is_fused_attn_available( - config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True - ): - pytest.skip("No attention backend available.") te_gpt = TransformerLayer( hidden_size=config.hidden_size, @@ -1016,10 +992,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] - if not is_fused_attn_available( - config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True - ): - pytest.skip("No attention backend available.") te_mha = MultiheadAttention( config.hidden_size, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 174d7ee9e..4cb39cda0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -983,7 +983,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0): + if is_training and device_compute_capability >= (10, 0): logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") use_fused_attention = False fused_attention_backend = None From b4a1d4d6f4f00a3b30d305c72cd040ae95ea41e4 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Mon, 20 Oct 2025 21:15:16 -0700 Subject: [PATCH 071/141] [PyTorch][MOE] Support NVFP4 Grouped Linear (#2215) * pipeclean, fix nvfp4 padding of 32 alignment Signed-off-by: Zhongbo Zhu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * numerical test passed Signed-off-by: Zhongbo Zhu * fix CI failure with test_cast_master_weights_to_fp8 (in a hacky way) Signed-off-by: Zhongbo Zhu * found CUDA mis-aligned address error in training in multi-swizzle, hack the vec_load_size to 1 to unblock Signed-off-by: Zhongbo Zhu * leave comments about alignment issue Signed-off-by: Zhongbo Zhu * fused bulk alloc nvfp4 Signed-off-by: Zhongbo Zhu * fix RHT sign mask CPU overhead Signed-off-by: Zhongbo Zhu * fix Signed-off-by: Zhongbo Zhu * resolve comments Signed-off-by: Zhongbo Zhu * Remove incorrect logic that treats 0-D tensor as uninitialized Tensor shape logic still requires treating 0-D tensor as uninitialized. Signed-off-by: Tim Moon * Fix invalid conversion from tensor to int Signed-off-by: Tim Moon --------- Signed-off-by: Zhongbo Zhu Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- benchmarks/linear/benchmark_grouped_linear.py | 72 ++++-- tests/pytorch/test_numerics.py | 93 +++++++- transformer_engine/common/common.h | 27 ++- transformer_engine/common/swizzle/swizzle.cu | 83 +++++-- transformer_engine/pytorch/csrc/common.cpp | 5 +- .../pytorch/csrc/extensions/cast.cpp | 213 +++++++++++++++++- .../pytorch/csrc/extensions/recipe.cpp | 3 +- transformer_engine/pytorch/csrc/quantizer.cpp | 10 +- transformer_engine/pytorch/csrc/util.cpp | 59 ++--- .../pytorch/module/fp8_padding.py | 11 +- .../pytorch/module/fp8_unpadding.py | 15 +- 11 files changed, 504 insertions(+), 87 deletions(-) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 48adb2a10..d4bbad75c 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -8,53 +8,67 @@ import pandas as pd from transformer_engine.pytorch.module import GroupedLinear -from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling +from transformer_engine.common.recipe import ( + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager from contextlib import nullcontext """ # Profile BF16 recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \ + --output=./benchmarks/linear/b200_numgemm_8_bf16 \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16 # Profile FP8 sub-channel recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \ + --output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel # Profile MXFP8 recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \ + --output=./benchmarks/linear/b200_numgemm_8_mxfp8 \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8 +# Profile NVFP4 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_numgemm_8_nvfp4 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 + """ RECIPES = { "bf16": None, "fp8_sub_channel": Float8BlockScaling(), "mxfp8": MXFP8BlockScaling(), + "nvfp4": NVFP4BlockScaling(), } mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( FP8GlobalStateManager.is_fp8_block_scaling_available() ) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None): assert mode in ["fwd_only", "fwd_bwd"] - fp8_context = autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() - # print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}") + quantization_context = ( + autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext() + ) if mode == "fwd_only": - with torch.no_grad(), fp8_context: + with torch.no_grad(), quantization_context: for i in range(run_num_steps): y_q = layer.forward( x, @@ -67,7 +81,7 @@ def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps= layer.zero_grad() x.grad = None - with fp8_context: + with quantization_context: for i in range(run_num_steps): label = f"step_{i}" torch.cuda.nvtx.range_push(label) @@ -142,7 +156,7 @@ def benchmark_linear( "recipe": recipe, }, num_threads=1, - ).blocked_autorange(min_run_time=5) + ).blocked_autorange(min_run_time=10) print(f"{recipe_name}: {timing} \n") timing_ms = timing.median * 1000 / num_microbatches @@ -225,30 +239,44 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): use_bias = False # Set the MKN values to benchmark + # Deepseek V3 EP64, SEQ_LEN=8192, topK8 + # 256 expert => 4 local experts + # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384 + # M = AvgM * localExperts = 65536 + # K = 7168 + # N = 2048 + + # Deepseek V3 EP32, SEQ_LEN=8192, topK8 + # 256 expert => 8 local experts + # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192 + # M = AvgM * localExperts = 65536 + # K = 7168 + # N = 2048 + + # 4 or 8local experts per rank + num_gemms_list = [4, 8] + + # MKN for group linear mkns = [] - for m in [8192]: - # for m in [4096, 8192, 16384]: - # for n in [1024, 2048, 4096, 8192, 16384]: - for n in [8192]: - for k in [4096]: + for m in [65536]: + for k in [7168]: + for n in [2048]: mkns.append((m, k, n)) # default recipes to run if not specified recipe_list = ["bf16"] if args.recipe == "all": - recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"] + recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"] else: recipe_list = [args.recipe] - num_gemms_list = [8] - if args.profile: - mkns = [(4096 * 8, 4096, 4096)] + mkns = [(8192 * 8, 7168, 2048)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" - " fp8_sub_channel, mxfp8, or bf16" + " fp8_sub_channel, mxfp8, nvfp4, or bf16" ) recipe_list = [args.recipe] num_gemms_list = [8] @@ -265,13 +293,17 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): "bf16", "fp8_sub_channel", "mxfp8", - ], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8" + "nvfp4", + ], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4" if recipe_name == "mxfp8" and not mxfp8_available: print(f"MXFP8 is not available, skipping {recipe_name}") continue if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: print(f"FP8 block scaling is not available, skipping {recipe_name}") continue + if recipe_name == "nvfp4" and not nvfp4_available: + print(f"NVFP4 is not available, skipping {recipe_name}") + continue df = run_benchmark_linear( mkns, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 35698b819..01f1deb98 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -40,6 +40,7 @@ is_mxfp8_available, is_fp8_block_scaling_available, is_bf16_available, + is_nvfp4_available, ) from transformer_engine.pytorch import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm @@ -53,6 +54,7 @@ fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) fp8_block_scaling_available = is_fp8_block_scaling_available() +nvfp4_available = is_nvfp4_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -114,6 +116,43 @@ ) +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + # if using RHT, we can only support bf16 + # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + supported_input_dtypes.append(torch.bfloat16) + # if not using RHT, we can add fp32 as well + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) @@ -122,6 +161,8 @@ if fp8_available: fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +if nvfp4_available: + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper @@ -582,6 +623,11 @@ def _test_e2e_selective_recompute( def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) config = model_configs[model] @@ -692,6 +738,11 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) config = model_configs[model] @@ -1275,6 +1326,12 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): te_linear_ref = Linear( config.hidden_size, @@ -1718,8 +1775,8 @@ def _test_grouped_linear_accuracy( split_size = 1 if fp8: split_size = 16 - if recipe.mxfp8(): - split_size = 128 + if recipe.mxfp8() or recipe.nvfp4(): + split_size = 32 m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero @@ -1791,6 +1848,12 @@ def test_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, @@ -1927,6 +1990,12 @@ def test_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, @@ -2014,7 +2083,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): align_size = 16 - if recipe.mxfp8(): + if recipe.mxfp8() or recipe.nvfp4(): align_size = 32 padded_tokens_per_expert = [ (num_tokens + align_size - 1) // align_size * align_size @@ -2129,6 +2198,12 @@ def test_padding_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, @@ -2200,6 +2275,12 @@ def test_padding_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, @@ -2409,6 +2490,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + config = model_configs[model] outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index bddd9bf19..97b130952 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -183,21 +183,38 @@ struct Tensor { * https://gcc.gnu.org/bugzilla/show_bug.cgi?id=109569). */ switch (scaling_mode) { - case NVTE_NVFP4_1D_SCALING: case NVTE_DELAYED_TENSOR_SCALING: - if (!has_data() && has_columnwise_data()) { + case NVTE_NVFP4_1D_SCALING: { + // Choose data buffer based on whether it is initialized + // Note: Uninitialized buffers currently have shape=[]. + // However, this is logically incorrect. 0-D tensors have 1 + // entry, and uninitialized tensors should have shape=[0]. + bool use_columnwise_shape = false; + if (data.dptr != nullptr) { + use_columnwise_shape = false; + } else if (columnwise_data.dptr != nullptr) { + use_columnwise_shape = true; + } else if (data.shape.size() != 0) { + use_columnwise_shape = false; + } else if (columnwise_data.shape.size() != 0) { + use_columnwise_shape = true; + } + + // Infer shape based on data + if (use_columnwise_shape) { + // Column-wise data is transposed std::vector ret; if (!columnwise_data.shape.empty()) { + ret.reserve(columnwise_data.shape.size()); for (size_t i = 1; i < columnwise_data.shape.size(); i++) { ret.push_back(columnwise_data.shape[i]); } ret.push_back(columnwise_data.shape.front()); } return ret; - } else { - return data.shape; } - break; + return data.shape; + } case NVTE_MXFP8_1D_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 36e06173d..06735e310 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -332,11 +332,9 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - NVTE_CHECK(input->scaling_mode == NVTE_MXFP8_1D_SCALING || - input->scaling_mode == NVTE_BLOCK_SCALING_1D || - input->scaling_mode == NVTE_BLOCK_SCALING_2D || - input->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + NVTE_CHECK( + input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (", to_string(input->dtype()), ")."); @@ -583,16 +581,19 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, NVTE_CHECK_CUDA(cudaGetLastError()); } -// TODO(nvfp4): Add NVFP4 support. void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; + bool all_nvfp4 = true; for (size_t i = 0; i < num_tensors; i++) { - if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { - NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); - } + auto scaling_mode = input[i]->scaling_mode; + auto is_fp8 = is_fp8_dtype(input[i]->dtype()); + auto is_fp4 = is_fp4_dtype(input[i]->dtype()); + NVTE_CHECK( + (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), + "Not implemented scaling mode " + to_string(scaling_mode) + "."); // We don't allow empty tensors. They should be filtered out before calling this function. if (input[i]->data.numel() == 0) { NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); @@ -601,13 +602,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); all_has_data &= input[i]->has_data(); all_has_columnwise_data &= input[i]->has_columnwise_data(); + all_nvfp4 &= is_nvfp4_scaling(scaling_mode); } NVTE_CHECK(all_has_data || all_has_columnwise_data, "All tensors should have data or columnwise data."); + const bool rowwise_swizzle = all_has_data || all_nvfp4; + const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4; + constexpr int SF_TILE_DIM_M = 128; constexpr int SF_TILE_DIM_K = 4; - if (all_has_data) { + if (rowwise_swizzle) { MultiSwizzleArgs kernel_args; kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; @@ -623,29 +628,60 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.num_tensors = 0; vec_load_size = 4; } - const int m = input[i]->scale_inv.shape[0]; - const int k = input[i]->scale_inv.shape[1]; + + int m, k; + + if (all_has_data) { + m = input[i]->scale_inv.shape[0]; + k = input[i]->scale_inv.shape[1]; + } else { + NVTE_CHECK(all_nvfp4, "When doing rowwise swizzle with rowwise data, it has to be NVFP4"); + m = input[i]->columnwise_scale_inv.shape[0]; + k = input[i]->columnwise_scale_inv.shape[1]; + } NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - NVTE_CHECK( - m * k == std::accumulate(output[i]->scale_inv.shape.begin(), - output[i]->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); + + if (output[i]->has_data()) { + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output[i]->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } int num_tiles_k = k / SF_TILE_DIM_K; int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; // We use the minimum vec_load_size across all tensors. - vec_load_size = std::min(vec_load_size, vec_load_size_i); + // TODO(zhongbo): fix vec_load_size for NVFP4 + // Current unit test won't capture this issue, but in E2E + // using vec_load_size = 1 other than 1 will lead to mis-aligned + // address error in MOE training + vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i); const int pos = kernel_args.num_tensors; - kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); - kernel_args.output_list[pos] = output[i]->scale_inv.dptr; kernel_args.m_list[pos] = m; kernel_args.k_list[pos] = k; - kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); - kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + if (!all_nvfp4 || all_has_data) { + int block_scale_size = all_nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size; + } else { + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE; + } kernel_args.num_tensors++; } // Launch the remaining tensors @@ -655,7 +691,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args, vec_load_size, true, stream); } - if (all_has_columnwise_data) { + if (columnwise_swizzle) { + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); + MultiSwizzleArgs kernel_args; kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 49ae963d7..e054424dd 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -190,8 +190,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( const std::vector meta_shape{1}; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 + : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3 + : DType::kFloat32; ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, columnwise_scale_inv_shape); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b6e9ef828..7d15e436e 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -491,6 +491,207 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } +// allocate fp4 data, fp8 scalings, and amax values +// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] +// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate +std::tuple, std::vector> bulk_allocate_nvfp4_tensors( + std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { + init_extension(); + std::tuple, std::vector> retval; + auto &tensor_py_list = std::get<0>(retval); + auto &tensor_cpp_list = std::get<1>(retval); + + // Number of tensors + const size_t num_tensors = shape_list.size(); + if (num_tensors == 0) { + return retval; + } + + // Quantization parameters + const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); + const auto fp4_dtype = quantizer_cpp_list[0]->dtype; + constexpr size_t scale_elem_size = 1; + + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + size_t offset, at::ScalarType dtype) -> at::Tensor { + std::vector shape_int64(shape.begin(), shape.end()); + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { + return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); + } + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); + }; + + // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const std::vector &shape) { + std::vector fp4_shape(shape.begin(), shape.end()); + if (!fp4_shape.empty()) { + fp4_shape.back() /= 2; + } + return fp4_shape; + }; + + // Allocate row-wise data + std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; + std::vector> rowwise_data_shapes, rowwise_scale_shapes; + if (rowwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_shapes.emplace_back(shape_list[i]); + rowwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets, amax_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes). + // Integer arithmetic: ceil(product / 2) == (product + 1) / 2. + buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + amax_offsets.push_back(buffer_size); + // amax is scalar in fp32, 4 bytes each + buffer_size += 4; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), + data_offsets[i], torch::kUInt8)); + rowwise_scale_list.emplace_back( + make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + amax_rowwise_list.emplace_back( + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + } + } + + // Allocate column-wise data + std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; + std::vector> columnwise_data_shapes, columnwise_scale_shapes; + if (columnwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + // push the transposed shape into NVFP4 columnwise shape + // NVFP4 on SM100 is TN only + columnwise_data_shapes.emplace_back(); + auto &shape = columnwise_data_shapes.back(); + shape.push_back(shape_list[i].back()); + for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { + shape.push_back(shape_list[i][j]); + } + columnwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets, amax_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes). + // Integer arithmetic: ceil(product / 2) == (product + 1) / 2. + buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + amax_offsets.push_back(buffer_size); + // amax is scalar in fp32, 4 bytes each + buffer_size += 4; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_list.emplace_back(make_torch_view( + buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); + columnwise_scale_list.emplace_back( + make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + amax_columnwise_list.emplace_back( + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + } + } + + // Construct nvfp4 tensors + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); + for (size_t i = 0; i < num_tensors; ++i) { + // Create tensor objects with proper reference counting + py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); + py::object columnwise_data = + (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); + py::object columnwise_scale = + (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); + py::object amax_rowwise = rowwise_usage ? py::cast(amax_rowwise_list[i]) : py::none(); + py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); + + // Construct Python tensor + tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, amax_rowwise, amax_columnwise, + fp4_dtype, quantizer_py_list[i])); + + // Construct C++ tensor + // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, + // then set the amax and amax_columnwise values. + { + auto tensor_wrapper = makeTransformerEngineTensor( + rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp4_dtype, + /*amax_ptr=*/nullptr, + /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_scale_shapes[i] : std::vector{}, scaling_mode); + + // Set the amax rowwise and amax columnwise if available + if (rowwise_usage) { + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (columnwise_usage) { + tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, + std::vector{1}); + } + tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); + } + } + + return retval; +} + } // namespace std::vector split_quantize(const at::Tensor &tensor, @@ -549,7 +750,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool use_fused_bulk_alloc = true; for (size_t i = 0; i < quantizer_list.size(); i++) { if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) && - !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) { + !detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) && + !detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) { use_fused_bulk_alloc = false; break; } @@ -570,6 +772,7 @@ std::vector split_quantize(const at::Tensor &tensor, // TODO(zhongbo): make a better api to make this part less hacky bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr()); bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr()); + bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr()); if (is_fp8_blockwise) { // FP8 block-scaling: construct output tensors with bulk allocations std::vector blockwise_quantizers; @@ -586,6 +789,14 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); + } else if (is_nvfp4) { + // NVFP4: construct output tensors with bulk allocations + std::vector nvfp4_quantizers; + for (auto &quantizer : quantizer_cpp_list) { + nvfp4_quantizers.push_back(static_cast(quantizer.get())); + } + std::tie(output_py_list, output_cpp_list) = + bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); } else { NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); } diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 3635d4a9c..8d1d86560 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -20,10 +20,11 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); + auto* amax_ptr = amax.data_ptr(); TensorWrapper fake_te_output( nullptr, te_input.shape(), DType::kFloat8E4M3, // It doesn't matter because we only compute amax. - amax.data_ptr()); + amax_ptr); nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 42ae658f2..d7e8912ac 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1200,6 +1200,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({1}, bit32_tensor_opts); } if (columnwise_usage) { @@ -1213,6 +1215,8 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_data_tensor = at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed amax_columnwise = at::empty({1}, bit32_tensor_opts); } @@ -1352,6 +1356,8 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_rowwise) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed amax_rowwise = at::empty({1}, opts); tensor.attr("_amax_rowwise") = *amax_rowwise; } @@ -1392,7 +1398,9 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } if (!amax_columnwise) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); - amax_columnwise = at::zeros({1}, opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed + amax_columnwise = at::empty({1}, opts); tensor.attr("_amax_columnwise") = *amax_columnwise; } } else { // columnwise_usage == false diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index ffba5b276..134185ac8 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -50,8 +50,6 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap void* scale_inv_dptr = scale_inv.data_ptr; void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); - // Reconstruct input only to avoid swizzling both directions if not needed. - // The specific dtype used is irrelevant, just needs to be correct bits. transformer_engine::TensorWrapper input_cu(input.scaling_mode()); transformer_engine::TensorWrapper output_cu(input.scaling_mode()); @@ -100,10 +98,14 @@ std::optional multi_tensor_swizzle_scaling_factors( if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING && + tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } + const auto scaling_mode = tensors.front().scaling_mode(); + const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; + std::vector wrappers; std::vector input_tensors, output_tensors; @@ -131,39 +133,44 @@ std::optional multi_tensor_swizzle_scaling_factors( // Allocate full buffer auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + for (size_t i = 0; i < tensors.size(); ++i) { auto& tensor = tensors[i]; void* scale_inv_dptr = scale_inv_dptrs[i]; void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); - auto input_shape = nvte_shape_to_vector(tensor.shape()); - + // auto input_shape = nvte_shape_to_vector(tensor.shape()); + NVTEShape nvte_input_shape; + if (rowwise) { + nvte_input_shape = tensor.shape(); + } else { + nvte_input_shape = tensor.get_columnwise_data().shape; + } + auto input_shape = nvte_shape_to_vector(nvte_input_shape); // Reconstruct input only to avoid swizzling both directions if not needed. // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper input_cu(scaling_mode); + transformer_engine::TensorWrapper output_cu(scaling_mode); if (rowwise) { - input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); - output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); // Set the swizzled scaling factor to the original tensor. - tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); } else { - input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); - output_cu.set_columnwise_data(tensor.columnwise_dptr(), - transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_columnwise_scale_inv( - swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); // Set the swizzled scaling factor to the original tensor. - tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); } input_tensors.emplace_back(input_cu.data()); diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 5d569d59d..fca89fbaa 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -78,7 +78,7 @@ class Fp8Padding(torch.nn.Module): number of GEMMs to be performed simultaneously. align_size : int, optional the alignment size for the input tensor. If not provided, the alignment size will - be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first + be determined by the FP8/FP4 recipe (32 for MXFP8/NVFP4 and 16 for others) in the first forward pass. """ @@ -111,7 +111,14 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." if self.align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + self.align_size = ( + 32 + if ( + FP8GlobalStateManager.get_fp8_recipe().mxfp8() + or FP8GlobalStateManager.get_fp8_recipe().nvfp4() + ) + else 16 + ) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index b74395dd8..7a01f1572 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module): num_gemms : int number of GEMMs to be performed simultaneously. align_size : int, optional - the alignment size for the input tensor. If not provided, the alignment size will - be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first - forward pass. + The alignment size for the input tensor. If not provided, the alignment size will + be automatically determined based on the FP8/FP4 recipe in the first forward pass: + 32 for MXFP8 or NVFP4, otherwise 16. """ def __init__( @@ -109,7 +109,14 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." if self.align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + self.align_size = ( + 32 + if ( + FP8GlobalStateManager.get_fp8_recipe().mxfp8() + or FP8GlobalStateManager.get_fp8_recipe().nvfp4() + ) + else 16 + ) # FP8 padding calculate padded_m_splits = [ From e90582f2010deae477a71bad0aacf278dd5abfa4 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Tue, 21 Oct 2025 19:52:37 +0200 Subject: [PATCH 072/141] [Common] Removed activations from NVFP4 quantize C++ unit tests (#2289) * Removed activations from NVFP4 CPP tests. Removed CMake debugging flags Signed-off-by: Oleg Goncharov * Better wording Signed-off-by: Oleg Goncharov --------- Signed-off-by: Oleg Goncharov --- tests/cpp/operator/CMakeLists.txt | 6 ------ tests/cpp/operator/test_cast_nvfp4_transpose.cu | 9 ++------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 479d378ba..b2f14b189 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -32,12 +32,6 @@ add_executable(test_operator test_swap_first_dims.cu ../test_common.cu) -# Add profiling and debug flags for CUDA compilation -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -lineinfo") # Generate line info for device code -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -g") # Add debug symbols for host code -set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --ptxas-options=-v") # Add info about registers usage -# Note: Using -lineinfo instead of -G to avoid conflicts and get line mapping - # Find required packages find_package(OpenMP REQUIRED) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index e905a0064..afd7927da 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -661,14 +661,9 @@ std::vector> tensor_dims = { {4096, 13312}, }; -// Only GeLU activation tests are supported +// Only the Identity activation is currently supported. std::vector Activation_types = { - ActivationType::Identity, - ActivationType::GeLU, - ActivationType::SiLU, - ActivationType::ReLU, - ActivationType::QGeLU, - ActivationType::SReLU, + ActivationType::Identity }; } // namespace From ce2f9fa4632a688d45efc55586be18cd0931ea50 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 21 Oct 2025 12:29:09 -0700 Subject: [PATCH 073/141] [JAX] HuggingFace login in JAX examples if token is available (#2290) HF login in JAX examples Signed-off-by: Jeremy Berchtold --- examples/jax/encoder/common.py | 11 +++++++++++ examples/jax/encoder/test_model_parallel_encoder.py | 2 ++ examples/jax/encoder/test_multigpu_encoder.py | 7 ++++++- examples/jax/encoder/test_multiprocessing_encoder.py | 2 ++ examples/jax/encoder/test_single_gpu_encoder.py | 7 ++++++- examples/jax/mnist/test_single_gpu_mnist.py | 8 +++++++- 6 files changed, 34 insertions(+), 3 deletions(-) diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index 772d5f4c1..9ffcfe57d 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -118,3 +118,14 @@ def get_quantization_recipe_from_name_string(name: str): return recipe.NVFP4BlockScaling() case _: raise ValueError(f"Invalid quantization_recipe, got {name}") + + +def hf_login_if_available(): + """Login to HF hub if available""" + try: + from huggingface_hub import login + + login() + except Exception as e: + print(e) + pass diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 7807d1fd9..c6d867ef9 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -23,12 +23,14 @@ is_bf16_supported, get_quantization_recipe_from_name_string, assert_params_sufficiently_sharded, + hf_login_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode +hf_login_if_available() DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 8ea1dcde3..1004dd2dd 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -19,12 +19,17 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding -from common import is_bf16_supported, get_quantization_recipe_from_name_string +from common import ( + is_bf16_supported, + get_quantization_recipe_from_name_string, + hf_login_if_available, +) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode +hf_login_if_available() DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 7e708466c..c2e97029b 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -27,11 +27,13 @@ is_mxfp8_supported, is_nvfp4_supported, get_quantization_recipe_from_name_string, + hf_login_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax +hf_login_if_available() os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 79178485c..1c62de7fa 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -16,11 +16,16 @@ from flax import linen as nn from flax.training import train_state -from common import is_bf16_supported, get_quantization_recipe_from_name_string +from common import ( + is_bf16_supported, + get_quantization_recipe_from_name_string, + hf_login_if_available, +) import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode +hf_login_if_available() PARAMS_KEY = "params" DROPOUT_KEY = "dropout" diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index d0aebeb53..2e9d56e93 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -22,7 +22,13 @@ DIR = str(Path(__file__).resolve().parents[1]) sys.path.append(str(DIR)) -from encoder.common import is_bf16_supported, get_quantization_recipe_from_name_string +from encoder.common import ( + is_bf16_supported, + get_quantization_recipe_from_name_string, + hf_login_if_available, +) + +hf_login_if_available() IMAGE_H = 28 IMAGE_W = 28 From 2712bb95cb4a4f7d1f2b8b473a2240ac3d6e7e58 Mon Sep 17 00:00:00 2001 From: Kunlun Li <94586211+kunlunl@users.noreply.github.com> Date: Wed, 22 Oct 2025 07:04:04 +0800 Subject: [PATCH 074/141] Add post-processing API for FP8 primary weights to support CUDA Graph (#2266) * Add post-processing API for FP8 primary weights to support CUDA Graph Signed-off-by: kunlunl * Add post-processing support for plain pytorch tensors Signed-off-by: kunlunl * Update type hint Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: kunlunl Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../run_cast_master_weights_to_fp8.py | 46 +++++++++++-------- transformer_engine/pytorch/tensor/utils.py | 43 +++++++++-------- 2 files changed, 50 insertions(+), 39 deletions(-) diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py index 976991633..2f11a24ee 100644 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py @@ -27,7 +27,7 @@ Float8BlockwiseQTensor, ) from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 -from transformer_engine.pytorch.tensor.utils import replace_raw_data +from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data def _get_raw_data(quantized_tensor): @@ -203,12 +203,15 @@ def step(self): # ----------------------------------------------------------------------------------------- # Step 7: Copy the gathered weights from weight buffer to the actual weights # ----------------------------------------------------------------------------------------- + quantized_weights = [] for weight, offset in zip(self.weights, self.offsets[:-1]): start = offset end = offset + weight.numel() if isinstance(weight, QuantizedTensor): + quantized_weights.append(weight) weight = _get_raw_data(weight) weight.view(-1).data.copy_(self.weight_buffer[start:end]) + post_all_gather_processing(quantized_weights) class MiniOptimizer: @@ -252,10 +255,6 @@ def __init__(self, weights, lr, dp_group): self.dp_group = dp_group # Flatten the weights and pad to align with world size - raw_data_list = [ - _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1) - for w in weights - ] if isinstance(weights[0], QuantizedTensor): raw_data_list = [_get_raw_data(w).view(-1) for w in weights] else: @@ -264,7 +263,9 @@ def __init__(self, weights, lr, dp_group): # Split flattened weights into shards self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] - self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard) + self.local_main_grad_shard = torch.zeros_like( + self.local_weight_shard, dtype=torch.float32, device="cuda" + ) shard_size = self.flatten_weight.size(0) // world_size # Map original tensors to flattened indices @@ -341,9 +342,8 @@ def _flatten_tensors_with_pad(self, tensors): padding_needed = (world_size - original_length % world_size) % world_size if padding_needed > 0: - flatten_tensor = torch.cat( - [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)] - ) + zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda") + flatten_tensor = torch.cat([flatten_tensor, zeros]) return flatten_tensor, original_length @@ -369,10 +369,10 @@ def step(self): main_grad_buffer, _ = self._flatten_tensors_with_pad( [weight.main_grad.view(-1) for weight in self.weights] ) - main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype) dist.reduce_scatter_tensor( self.local_main_grad_shard, main_grad_buffer, group=self.dp_group ) + self.local_main_grad_shard /= dist.get_world_size(self.dp_group) # Step 2: Update the master weights for weight, master_weight, (shard_start, shard_end) in zip( @@ -416,6 +416,11 @@ def step(self): dist.all_gather_into_tensor( self.flatten_weight, self.local_weight_shard, group=self.dp_group ) + quantized_weights = [] + for weight in self.weights: + if isinstance(weight, QuantizedTensor): + quantized_weights.append(weight) + post_all_gather_processing(quantized_weights) def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): @@ -435,7 +440,7 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): linear_kwargs = { "params_dtype": torch.bfloat16, "bias": False, - "fuse_wgrad_accumulation": False, + "fuse_wgrad_accumulation": True, } # Create model with FP8 weights @@ -503,14 +508,9 @@ def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) - print( - f"✅ Successfully validated FSDP {NUM_STEPS} training steps with" - f" {quantization} quantization" - ) - -def _test_zero_1(dp_group): - """Make sure the implementation of zero-1 optimizer is correct""" +def _test_mini_optimizer(dp_group): + """Make sure the implementation of MiniZero_1 and MiniFSDP is correct""" rank = dist.get_rank(dp_group) world_size = dist.get_world_size(dp_group) @@ -525,13 +525,15 @@ def _test_zero_1(dp_group): weights_1 = weights weights_2 = [weight.clone() for weight in weights] + weights_3 = [weight.clone() for weight in weights] lr = 1.0 optimizer_1 = MiniZero_1(weights_1, lr, dp_group) optimizer_2 = MiniOptimizer(weights_2, lr, dp_group) + optimizer_3 = MiniFSDP(weights_3, lr, dp_group) for _ in range(100): - for w1, w2 in zip(weights_1, weights_2): + for w1, w2, w3 in zip(weights_1, weights_2, weights_3): main_grads = [ torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size) ] @@ -539,12 +541,16 @@ def _test_zero_1(dp_group): main_grad = main_grads[rank] w1.main_grad = main_grad w2.main_grad = main_grad + w3.main_grad = main_grad optimizer_1.step() optimizer_2.step() + optimizer_3.step() for w1, w2 in zip(weights_1, weights_2): torch.testing.assert_close(w1, w2, atol=0, rtol=0) + for w1, w3 in zip(weights_1, weights_3): + torch.testing.assert_close(w1, w3, atol=0, rtol=0) def quantization_recipe(quantization) -> Recipe: @@ -671,7 +677,7 @@ def main(argv=None, namespace=None): args = parser.parse_args(argv, namespace) dp_group = dist.new_group(backend="nccl") - _test_zero_1(dp_group) + _test_mini_optimizer(dp_group) _test_cast_master_weights_to_fp8(args.quantization, dp_group) _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index cc0249401..72c465edb 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -5,7 +5,7 @@ """Helper functions for using fp8 tensors as weights""" import os -from typing import Optional, Union +from typing import Optional, List, Union import torch import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv @@ -15,6 +15,7 @@ from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier +from ..utils import is_non_tn_fp8_gemm_supported def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -159,12 +160,6 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo amaxes, scales, scale_invs = [], [], [] for model_weight, master_weight, start_offset, shard_model_weight_raw in params: - # Reset transpose cache for all model weights. - # We cannot create transpose cache here because users (like megatron) may want to overlap - # the all-gather of model weights and forward process, so the model weight is not updated - # currently. - model_weight._reset_caches() - quantizer = model_weight._get_quantizer() amaxes.append(quantizer.amax.view(1)) @@ -302,12 +297,6 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( params, scales ): - # Reset transpose cache for all model weights. - # We cannot create transpose cache here because users (like megatron) may want to overlap - # the all-gather of model weights and forward process, so the model weight is not updated - # currently. - model_weight._reset_caches() - # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. if master_weight is None: @@ -432,12 +421,6 @@ def _cast_master_weights_to_fp8_blockwise_scaling( for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( params, scales ): - # Clear columnwise data for all model weights. - # We cannot create columnwise data here because users (like megatron) may want to overlap - # the all-gather of model weights and forward process, so the model weight is not updated - # at this moment. - model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) - # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. if master_weight is None: @@ -454,6 +437,28 @@ def _cast_master_weights_to_fp8_blockwise_scaling( ) +def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): + """ + Post-processing after all-gather for weights in distributed optimizer. + - Float8Tensor: may need to create a transposed view to match backend GEMM. + - Float8BlockwiseQTensor: create column-wise storage. + - Plain pytorch tensor: noop. + """ + if not isinstance(model_weights, list): + model_weights = [model_weights] + for model_weight in model_weights: + if isinstance(model_weight, Float8Tensor): + # Delayed scaling and per-tensor current scaling: if backend does not support + # non-transposed FP8 GEMM, pre-create the transpose. + if not is_non_tn_fp8_gemm_supported(): + model_weight._create_transpose() + elif isinstance(model_weight, Float8BlockwiseQTensor): + # Blockwise scaling: create column-wise storage. + model_weight._create_columnwise() + elif isinstance(model_weight, QuantizedTensor): + raise ValueError(f"post_processing for {type(model_weight)} is not supported") + + def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: """Check if an environment or object is using experimental Kitchen middleware. From ce2e8bd12edfe10647bec8f54fedc394d6287b58 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Wed, 22 Oct 2025 14:58:27 +0200 Subject: [PATCH 075/141] [PyTorch] Decouple python quantization classes and refactor custom quantization (#2276) * rename experimental -> custom_recipes Signed-off-by: Evgeny * Decouple python base classes (api) Signed-off-by: Evgeny * update test_custom_recipe Signed-off-by: Evgeny * Rename experimental -> custom Signed-off-by: Evgeny * Minor Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix import Signed-off-by: Evgeny * Update tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Evgeny Tsykunov * Update tests/pytorch/test_custom_recipe.py Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Evgeny Tsykunov * quantization_base -> quantized_tensor rename Signed-off-by: Evgeny --------- Signed-off-by: Evgeny Signed-off-by: Evgeny Tsykunov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/attention/test_attention.py | 3 +- .../pytorch/distributed/run_numerics_exact.py | 6 +- .../test_fusible_ops_with_userbuffers.py | 1 + .../distributed/test_numerics_exact.py | 2 +- tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 4 +- .../pytorch/nvfp4/test_nvfp4_module_exact.py | 4 +- .../nvfp4/test_nvfp4_quantize_exact.py | 4 +- .../nvfp4/test_nvfp4_rht_quantize_exact.py | 6 +- tests/pytorch/test_custom_recipe.py | 42 ++++++++++ .../debug/pytorch/debug_quantization.py | 2 +- transformer_engine/pytorch/__init__.py | 10 +-- .../dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 4 +- .../pytorch/cpp_extensions/fused_attn.py | 2 +- .../pytorch/cpp_extensions/gemm.py | 12 +-- transformer_engine/pytorch/cpu_offload.py | 2 +- .../__init__.py | 0 .../{experimental => custom_recipes}/gemm.py | 12 +-- .../quantization.py | 0 .../quantization_nvfp4.py | 14 ++-- .../{experimental => custom_recipes}/utils.py | 0 transformer_engine/pytorch/distributed.py | 2 +- transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_linear.py | 12 +-- .../pytorch/module/layernorm_mlp.py | 12 +-- transformer_engine/pytorch/module/linear.py | 14 ++-- transformer_engine/pytorch/ops/_common.py | 2 +- .../ops/fused/userbuffers_backward_linear.py | 2 +- .../ops/fused/userbuffers_forward_linear.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 2 +- transformer_engine/pytorch/permutation.py | 2 +- .../pytorch/{tensor => }/quantized_tensor.py | 76 ++--------------- transformer_engine/pytorch/tensor/__init__.py | 2 +- .../pytorch/tensor/_quantization_helpers.py | 84 +++++++++++++++++++ .../pytorch/tensor/float8_blockwise_tensor.py | 7 +- .../pytorch/tensor/float8_tensor.py | 7 +- .../pytorch/tensor/mxfp8_tensor.py | 7 +- .../pytorch/tensor/nvfp4_tensor.py | 3 +- .../float8_blockwise_tensor_storage.py | 4 +- .../tensor/storage/float8_tensor_storage.py | 4 +- .../tensor/storage/mxfp8_tensor_storage.py | 4 +- .../tensor/storage/nvfp4_tensor_storage.py | 3 +- transformer_engine/pytorch/tensor/utils.py | 19 ++--- transformer_engine/pytorch/utils.py | 2 +- 45 files changed, 227 insertions(+), 181 deletions(-) rename transformer_engine/pytorch/{experimental => custom_recipes}/__init__.py (100%) rename transformer_engine/pytorch/{experimental => custom_recipes}/gemm.py (90%) rename transformer_engine/pytorch/{experimental => custom_recipes}/quantization.py (100%) rename transformer_engine/pytorch/{experimental => custom_recipes}/quantization_nvfp4.py (98%) rename transformer_engine/pytorch/{experimental => custom_recipes}/utils.py (100%) rename transformer_engine/pytorch/{tensor => }/quantized_tensor.py (89%) create mode 100644 transformer_engine/pytorch/tensor/_quantization_helpers.py diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 7dc6caeb8..3150c06ab 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -45,7 +45,8 @@ ) from transformer_engine.pytorch.utils import get_cudnn_version import transformer_engine_torch as tex -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( + Quantizer, prepare_for_saving, restore_from_saved, ) diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index ccbc3259b..3605b3c70 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -22,8 +22,8 @@ ) from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE -from transformer_engine.pytorch.experimental import quantization_nvfp4 -from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import utils from run_layer_with_overlap import _compare_tensors @@ -486,7 +486,7 @@ def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs): sequence_parallel (bool): Enable sequence parallelism if True. kwargs (dict): Additional arguments for the linear layer. - QUANTIZATION options: nvfp4 <=> experimental nvfp4 as a reference + QUANTIZATION options: nvfp4 <=> custom nvfp4 as a reference """ params_dtype = torch.bfloat16 use_bias = kwargs.get("bias", True) diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 24112cc9f..61c813b8f 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -34,6 +34,7 @@ Float8Tensor, ) + # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py index fd6ef65e0..72aa78664 100644 --- a/tests/pytorch/distributed/test_numerics_exact.py +++ b/tests/pytorch/distributed/test_numerics_exact.py @@ -14,7 +14,7 @@ Distributed numerics tests This numerical test aims for zero tolerance test for absolute confidence in numerics. - In the case of NVFP4, with the experimental NVFP4 quantization, we matched bitwise + In the case of NVFP4, with the custom NVFP4 quantization, we matched bitwise result with the native silicon. For distrbuted test cases, we can do the same by thing by comparing BF16 AG results with the low precision AG results at layer level. """ diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 77cfaaffe..6009643ff 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,8 +8,8 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef -from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index 44f222b9d..0292063ab 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -6,8 +6,8 @@ import torch import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.experimental import quantization_nvfp4 -from transformer_engine.pytorch.experimental import utils +from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import utils recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 8c2444557..2467c7e2e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -7,10 +7,10 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.experimental import utils recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 6f2f846a3..904dfc2ea 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -12,10 +12,10 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.experimental.quantization_nvfp4 import NVFP4QuantizerRef -from transformer_engine.pytorch.experimental import utils +from transformer_engine.common.recipe import NVFP4BlockScaling import pytest import torch diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 516354a34..64f1c3d15 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -17,6 +17,48 @@ Float8CurrentScalingQuantizer, ) import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( + nvfp4_ref_rht_2d_quantizer_factory, +) + + +@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear"]) +def test_custom_recipe_sanity_modules_nvfp4(module_type): + """Test modules with NVFP4 custom recipe support""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + # Simple linear layer with dims divisible by 16 + in_features = 64 + out_features = 64 + batch = 32 + + if module_type == "Linear": + model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + elif module_type == "LayerNormLinear": + model = LayerNormLinear( + in_features, out_features, params_dtype=torch.bfloat16, bias=False + ).cuda() + else: # OpsLinear + model = te_ops.Linear( + in_features, out_features, device="cuda", dtype=torch.bfloat16, bias=False + ) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Use NVFP4 quantizer factory + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) + + # Execute with custom recipe + with autocast(enabled=True, recipe=custom_recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Basic sanity: gradients exist + assert inp.grad is not None @pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"]) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 185bf15d0..7f45a24e2 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -15,7 +15,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensor, Quantizer, QuantizedTensorStorage, diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 77c71b811..9d894a389 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -66,24 +66,24 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy -from transformer_engine.pytorch.tensor import Quantizer +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.quantized_tensor import Quantizer +from transformer_engine.pytorch.quantized_tensor import prepare_for_saving +from transformer_engine.pytorch.quantized_tensor import restore_from_saved from transformer_engine.pytorch.tensor import Float8Quantizer from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer from transformer_engine.pytorch.tensor import MXFP8Quantizer from transformer_engine.pytorch.tensor import Float8BlockQuantizer from transformer_engine.pytorch.tensor import NVFP4Quantizer -from transformer_engine.pytorch.tensor import QuantizedTensorStorage from transformer_engine.pytorch.tensor import Float8TensorStorage from transformer_engine.pytorch.tensor import MXFP8TensorStorage from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage from transformer_engine.pytorch.tensor import NVFP4TensorStorage -from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor import Float8Tensor from transformer_engine.pytorch.tensor import MXFP8Tensor from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor import NVFP4Tensor -from transformer_engine.pytorch.tensor import prepare_for_saving -from transformer_engine.pytorch.tensor import restore_from_saved try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 6dfe0d31b..6c19d868a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -24,7 +24,7 @@ Float8Quantizer, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensorStorage, prepare_for_saving, restore_from_saved, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index a474cb809..e5ee8cc7d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -21,7 +21,7 @@ ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser from transformer_engine.pytorch.constants import ( dist_group_type, @@ -33,7 +33,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( prepare_for_saving, restore_from_saved, ) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 94a12c4a0..f80c001a1 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -15,7 +15,7 @@ NVTE_Softmax_Type, NVTE_Fused_Attn_Backend, ) -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer __all__ = [ diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index a45fafb68..dd0411298 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -11,10 +11,10 @@ from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage -from ..tensor.utils import is_experimental -from ..experimental.gemm import experimental_gemm +from ..tensor.utils import is_custom +from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ @@ -79,9 +79,9 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") - # If A or B are experimental tensors -> dispatch to quantizers's qgemm implementation - if is_experimental(A) or is_experimental(B): - return experimental_gemm( + # If A or B are custom tensors -> dispatch to quantizers's qgemm implementation + if is_custom(A) or is_custom(B): + return custom_gemm( A, B, workspace, diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 648b21eb4..6edc12620 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -10,7 +10,7 @@ import torch from transformer_engine.debug.pytorch.debug_state import TEDebugState -from .tensor.quantized_tensor import QuantizedTensorStorage +from .quantized_tensor import QuantizedTensorStorage from .tensor.float8_tensor import Float8Tensor __all__ = ["get_cpu_offload_context"] diff --git a/transformer_engine/pytorch/experimental/__init__.py b/transformer_engine/pytorch/custom_recipes/__init__.py similarity index 100% rename from transformer_engine/pytorch/experimental/__init__.py rename to transformer_engine/pytorch/custom_recipes/__init__.py diff --git a/transformer_engine/pytorch/experimental/gemm.py b/transformer_engine/pytorch/custom_recipes/gemm.py similarity index 90% rename from transformer_engine/pytorch/experimental/gemm.py rename to transformer_engine/pytorch/custom_recipes/gemm.py index 0bd740d85..cc98a8a57 100644 --- a/transformer_engine/pytorch/experimental/gemm.py +++ b/transformer_engine/pytorch/custom_recipes/gemm.py @@ -2,21 +2,21 @@ # # See LICENSE for license information. -"""GEMM API for experimental middleware between Transformer Engine and Kitchen.""" +"""GEMM API that enables custom GEMM logic for custom quantization recipes.""" from typing import Iterable, Optional import torch -from transformer_engine.pytorch.experimental.quantization import ( +from transformer_engine.pytorch.custom_recipes.quantization import ( MMParams, GEMMType, ) -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer -from transformer_engine.pytorch.tensor.utils import is_experimental +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.tensor.utils import is_custom -def experimental_gemm( +def custom_gemm( A: QuantizedTensorStorage, B: QuantizedTensorStorage, workspace: torch.Tensor, # pylint: disable=unused-argument @@ -32,7 +32,7 @@ def experimental_gemm( grad: bool = False, ) -> Iterable[Optional[torch.Tensor]]: """Dispatch GEMM to quantizer's qgemm method.""" - assert is_experimental(A) and is_experimental(B), "A and B must be experimental tensors" + assert is_custom(A) and is_custom(B), "A and B must be custom tensors" A, B = B, A diff --git a/transformer_engine/pytorch/experimental/quantization.py b/transformer_engine/pytorch/custom_recipes/quantization.py similarity index 100% rename from transformer_engine/pytorch/experimental/quantization.py rename to transformer_engine/pytorch/custom_recipes/quantization.py diff --git a/transformer_engine/pytorch/experimental/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py similarity index 98% rename from transformer_engine/pytorch/experimental/quantization_nvfp4.py rename to transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index fc50d0742..1ce9079eb 100644 --- a/transformer_engine/pytorch/experimental/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -9,9 +9,9 @@ import torch -from transformer_engine.pytorch.experimental import quantization -from transformer_engine.pytorch.experimental import utils -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.custom_recipes import quantization +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer def nvfp4_ref_rht_2d_quantizer_factory(role): @@ -229,8 +229,8 @@ class NVFP4TensorRef(QuantizedTensorStorage): _quantizer: Optional[Quantizer] = None @property - def experimental(self) -> bool: - """Flag to indicate this quantizer is using experimental Kitchen middleware.""" + def custom(self) -> bool: + """Flag to indicate this quantized tensor is custom.""" return True def prepare_for_saving( @@ -362,8 +362,8 @@ def __init__( self.with_random_sign_mask = with_random_sign_mask @property - def experimental(self) -> bool: - """Flag to indicate this quantizer is using experimental Kitchen middleware""" + def custom(self) -> bool: + """Flag to indicate this quantizer is custom.""" return True @staticmethod diff --git a/transformer_engine/pytorch/experimental/utils.py b/transformer_engine/pytorch/custom_recipes/utils.py similarity index 100% rename from transformer_engine/pytorch/experimental/utils.py rename to transformer_engine/pytorch/custom_recipes/utils.py diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 5ed73f678..8c14d5ab7 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -41,7 +41,7 @@ from .tensor.mxfp8_tensor import MXFP8Quantizer from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer +from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer from .tensor.storage.float8_tensor_storage import Float8TensorStorage from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index d16455b5b..7f571ce01 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -38,7 +38,7 @@ _fsdp_gather_tensors, ) from ..constants import dist_group_type -from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer +from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index a5bf21ee1..aae85e2ca 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -43,7 +43,7 @@ from ..cpu_offload import is_cpu_offload_enabled from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer -from ..tensor.quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensorStorage, Quantizer, prepare_for_saving, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 05f2e9cde..933c7cde5 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -16,7 +16,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version -from transformer_engine.pytorch.tensor.utils import is_experimental +from transformer_engine.pytorch.tensor.utils import is_custom from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -56,7 +56,7 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ._common import apply_normalization, noop_cat, WeightGradStore -from ..tensor.quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, Quantizer, @@ -194,13 +194,13 @@ def forward( # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. - experimental = is_experimental(input_quantizer) + custom = is_custom(input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not experimental # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() + and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) # Apply normalization @@ -246,8 +246,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = input_quantizer - # experimental recipe doesn't need to support quantized AG - if not with_quantized_norm and not experimental: + # custom recipe doesn't need to support quantized AG + if not with_quantized_norm and not custom: ln_out = quantizer(ln_out) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a2ddb970a..bae0f2825 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -17,7 +17,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version -from transformer_engine.pytorch.tensor.utils import is_experimental +from transformer_engine.pytorch.tensor.utils import is_custom from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -70,7 +70,7 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ..tensor.quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensorStorage, Quantizer, prepare_for_saving, @@ -268,13 +268,13 @@ def forward( # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm - experimental = is_experimental(fc1_input_quantizer) + custom = is_custom(fc1_input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered - and not experimental + and not custom ) # Apply normalization @@ -314,8 +314,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer - # experimental recipe doesn't need to support quantized AG - if not with_quantized_norm and not experimental: + # custom recipe doesn't need to support quantized AG + if not with_quantized_norm and not custom: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3069c21d9..ccb84e664 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -57,7 +57,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor.quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensor, QuantizedTensorStorage, Quantizer, @@ -66,7 +66,7 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.utils import is_experimental +from ..tensor.utils import is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState @@ -153,8 +153,8 @@ def forward( ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG - # experimental recipe check - experimental = is_experimental(input_quantizer) or is_experimental(weight_quantizer) + # custom recipe check + custom = is_custom(input_quantizer) or is_custom(weight_quantizer) # ------------------------------------------------------ # Prepare input tensor @@ -178,7 +178,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorStorage) and not experimental: + if not isinstance(inputmat, QuantizedTensorStorage) and not custom: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -448,7 +448,7 @@ def forward( ctx.main_grad_func = lambda: weight.main_grad ctx.debug = debug - ctx.experimental = experimental + ctx.custom = custom ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -616,7 +616,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass - elif ctx.debug or ctx.experimental: + elif ctx.debug or ctx.custom: # Debug quantizer will be applied immediately before wgrad GEMM pass else: diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 52ca84b5d..a07ffea43 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -13,7 +13,7 @@ from .. import torch_version from ..quantization import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor -from ..tensor.quantized_tensor import QuantizedTensorStorage +from ..quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index d95b2298f..fd1820d15 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -21,7 +21,7 @@ get_ub, get_workspace, ) -from ...tensor.quantized_tensor import Quantizer +from ...quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ..basic import BasicLinear, Bias, ReduceScatter diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index e20de53da..057eb576d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -21,7 +21,7 @@ get_workspace, _2X_ACC_FPROP, ) -from ...tensor.quantized_tensor import Quantizer +from ...quantized_tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_dequantize, is_quantized_tensor diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 8ae112022..6026a40b6 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -28,7 +28,7 @@ fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( prepare_for_saving, restore_from_saved, ) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index ea3e67a57..f73bc9a96 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex import transformer_engine.pytorch.triton.permutation as triton_permutation from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py similarity index 89% rename from transformer_engine/pytorch/tensor/quantized_tensor.py rename to transformer_engine/pytorch/quantized_tensor.py index a524d5c8d..15f5b6bd5 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -2,10 +2,10 @@ # # See LICENSE for license information. -"""Tensor with quantized data""" +"""Pure Python base classes for quantization.""" from __future__ import annotations -from typing import Callable, Optional, Tuple, Iterable, Any, Dict, Union +from typing import Optional, Tuple, Iterable, Any, Dict, Union import abc import copy import warnings @@ -14,6 +14,11 @@ from torch.utils._pytree import tree_map from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor._quantization_helpers import ( + _QuantizeFunc, + _IdentityFunc, + _stride_from_shape, +) class QuantizedTensorStorage: @@ -310,73 +315,6 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-a return True -class _QuantizeFunc(torch.autograd.Function): - """Quantize tensor""" - - @staticmethod - def forward( - _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: torch.Tensor, - quantize_impl: Callable, - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - return quantize_impl(tensor) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, # unused - grad: torch.Tensor, - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if constructor kwargs are not provided - if init_kwargs is None: - return tensor.detach() - - # Construct new tensor if constructor kwargs are provided - ctx.input_dtype = tensor.dtype - kwargs = tensor.get_metadata() - for key, val in init_kwargs.items(): - kwargs[key] = val - return type(tensor)(tensor.shape, tensor.dtype, **kwargs) - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - grad_input = grad_output - if grad_input.dtype == ctx.input_dtype: - grad_input = grad_input.detach() - else: - grad_input = grad_input.to(ctx.input_dtype) - return grad_input, None - - -def _stride_from_shape(shape: list[int]): - if len(shape) == 0: - return [] - rstride = [1] - for d in reversed(shape[1:]): - rstride.append(rstride[-1] * d) - return list(reversed(rstride)) - - class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 7689e2019..ada624a90 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,7 +6,7 @@ import torch -from .quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensorStorage, QuantizedTensor, Quantizer, diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py new file mode 100644 index 000000000..2214edbff --- /dev/null +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -0,0 +1,84 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Private helper functions and classes for quantized tensor implementations. + +This module contains internal autograd functions and utilities that support +the quantization machinery. +""" + +from __future__ import annotations +from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING +import torch + +if TYPE_CHECKING: + from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + +class _QuantizeFunc(torch.autograd.Function): + """Quantize tensor""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: torch.Tensor, + quantize_impl: Callable, + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + return quantize_impl(tensor) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. + + """ + + @staticmethod + def forward( + ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if constructor kwargs are not provided + if init_kwargs is None: + return tensor.detach() + + # Construct new tensor if constructor kwargs are provided + ctx.input_dtype = tensor.dtype + kwargs = tensor.get_metadata() + for key, val in init_kwargs.items(): + kwargs[key] = val + return type(tensor)(tensor.shape, tensor.dtype, **kwargs) + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + grad_input = grad_output + if grad_input.dtype == ctx.input_dtype: + grad_input = grad_input.detach() + else: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input, None + + +def _stride_from_shape(shape: list[int]): + """Calculate stride from shape for contiguous tensors""" + if len(shape) == 0: + return [] + rstride = [1] + for d in reversed(shape[1:]): + rstride.append(rstride[-1] * d) + return list(reversed(rstride)) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 48762499b..8054374c8 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -14,11 +14,8 @@ from transformer_engine.common.recipe import Float8BlockScaling, Recipe from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage -from .quantized_tensor import ( - QuantizedTensor, - Quantizer, - _IdentityFunc, -) +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a4e68e53b..de112bb3f 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -14,11 +14,8 @@ from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func -from .quantized_tensor import ( - QuantizedTensor, - Quantizer, - _IdentityFunc, -) +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 700de24c4..5ef5708fd 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -17,11 +17,8 @@ from ..utils import devices_match, round_up_to_nearest_multiple from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func -from .quantized_tensor import ( - QuantizedTensor, - Quantizer, - _IdentityFunc, -) +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 5e2eeed72..7a5f8858f 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -22,7 +22,8 @@ ) from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc aten = torch.ops.aten diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 9040ea3a4..c2d5e8b3f 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -13,12 +13,10 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import Float8BlockScaleTensorFormat -from ..quantized_tensor import QuantizedTensorStorage +from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ...constants import TE_DType_To_Torch -from ..quantized_tensor import Quantizer - from ...utils import _empty_tensor diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index b9533edb6..a31f6a379 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -12,12 +12,10 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorStorage +from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ...constants import TE_DType as torch_to_transformer_engine_dtype -from ..quantized_tensor import Quantizer - from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index c1f30146c..2cca0829d 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -13,12 +13,10 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorStorage +from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ...constants import TE_DType as torch_to_transformer_engine_dtype -from ..quantized_tensor import Quantizer - from ...utils import _empty_tensor diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 350103f7c..67543a8e2 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -16,10 +16,9 @@ # import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorStorage +from ...quantized_tensor import QuantizedTensorStorage, Quantizer # from ...constants import TE_DType as torch_to_transformer_engine_dtype -from ..quantized_tensor import Quantizer from ...utils import _empty_tensor diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 72c465edb..8354823b3 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -4,13 +4,13 @@ """Helper functions for using fp8 tensors as weights""" -import os -from typing import Optional, List, Union +from typing import Optional, Union, List import torch + import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage +from ..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer @@ -459,18 +459,13 @@ def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Ten raise ValueError(f"post_processing for {type(model_weight)} is not supported") -def is_experimental(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: - """Check if an environment or object is using experimental Kitchen middleware. +def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: + """Check if an object is custom. Returns False if x is a torch.Tensor. """ - # Detect if the environment is experimental - if x is None: - return int(os.getenv("QAT_PARAMS", "0")) > 0 - - # Detect if the object is experimental - if isinstance(x, torch.Tensor): + if x is None or isinstance(x, torch.Tensor): return False if not isinstance(x, (Quantizer, QuantizedTensorStorage)): raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance") - return hasattr(x, "experimental") and x.experimental + return hasattr(x, "custom") and x.custom diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 2be0aed4a..90c628996 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -12,7 +12,7 @@ import torch from . import torch_version -from .tensor.quantized_tensor import Quantizer +from .quantized_tensor import Quantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor From 818b30cc4b07bcac955b17a6a12ca9708d7f0a7e Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 22 Oct 2025 08:51:36 -0700 Subject: [PATCH 076/141] [JAX] NVFP4 recipe with option to enable/disable SR, RHT, and 2D quantization (#2270) * [JAX] Support recipe flags for disabling SR, RHT, and 2D quantization Signed-off-by: Jeremy Berchtold * lint Signed-off-by: Jeremy Berchtold * Fix issue with SR state being erased due to pytree handling of NVFP4Quantizer Signed-off-by: Jeremy Berchtold * Add test for SR state preservation across VJP boundaries Signed-off-by: Jeremy Berchtold * Fix sharding of SR rng state Signed-off-by: Jeremy Berchtold * lint Signed-off-by: Jeremy Berchtold * update tolerances slightly now that SR is enabled Signed-off-by: Jeremy Berchtold * lint Signed-off-by: Jeremy Berchtold * Use hashlib for deterministic hashes across runs for SR Signed-off-by: Jeremy Berchtold * rename uses_rht on scaled tensors to has_applied_rht Signed-off-by: Jeremy Berchtold * add assert Signed-off-by: Jeremy Berchtold * Move decision of whether to use RHT into helper.py and add dedicated RHT tests Signed-off-by: Jeremy Berchtold * lint Signed-off-by: Jeremy Berchtold * fix use_rht attr usage Signed-off-by: Jeremy Berchtold * fix pure-jax rht usage criteria Signed-off-by: Jeremy Berchtold * Adjust tolerances after rebase Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- .../encoder/test_multiprocessing_encoder.py | 4 +- .../jax/encoder/test_single_gpu_encoder.py | 2 +- tests/jax/test_custom_call_compute.py | 155 ++++++++++++------ tests/jax/test_helper.py | 82 ++++++++- transformer_engine/jax/cpp_extensions/gemm.py | 16 +- .../jax/cpp_extensions/quantization.py | 40 +++-- .../jax/quantize/dequantizer.py | 11 +- transformer_engine/jax/quantize/hadamard.py | 26 --- transformer_engine/jax/quantize/helper.py | 90 +++++++--- transformer_engine/jax/quantize/metadata.py | 20 +++ transformer_engine/jax/quantize/quantizer.py | 34 +++- transformer_engine/jax/quantize/tensor.py | 25 +++ transformer_engine/jax/sharding.py | 13 ++ 13 files changed, 382 insertions(+), 136 deletions(-) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index c2e97029b..9605adf77 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -672,7 +672,7 @@ def test_te_mxfp8(self): def test_te_nvfp4(self): """Test Transformer Engine with NVFP4""" result = self.exec(True, "NVFP4BlockScaling") - assert result[0] < 0.451 and result[1] > 0.79 + assert result[0] < 0.451 and result[1] > 0.788 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): @@ -710,7 +710,7 @@ def test_te_mxfp8_shardy(self): def test_te_nvfp4_shardy(self): """Test Transformer Engine with NVFP4""" result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) - assert result[0] < 0.451 and result[1] > 0.79 + assert result[0] < 0.451 and result[1] > 0.788 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 1c62de7fa..81f2d6c74 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -390,7 +390,7 @@ def test_te_nvfp4(self): self.args.use_fp8 = True self.args.fp8_recipe = "NVFP4BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.476 and actual[1] > 0.775 + assert actual[0] < 0.477 and actual[1] > 0.769 if __name__ == "__main__": diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 2934e48df..1217ebf65 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,7 +40,6 @@ QuantizerFactory, QuantizeLayout, noop_quantizer_set, - should_use_rht, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -685,21 +684,14 @@ class TestQuantize: Purely quantization related tests that will always test on a wider set of types and shapes """ - def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): - """Temporary hack to skip unsupported FP4 cases until we implement them""" + def _skip_unsupported_dtypes(self, q_dtype, scaling_mode): + """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes.""" if q_dtype not in scaling_mode.get_compatible_q_dtypes(): pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}") return - # HACK: FIXME TODO(jberchtold) - row = reduce(operator.mul, input_shape[flatten_axis:], 1) - col = reduce(operator.mul, input_shape[:flatten_axis], 1) - will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout) - if will_use_rht and (row % 64 != 0 or col % 128 != 0): - pytest.skip("Unfused RHT is not supported currently, skipping") - def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): - self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + self._skip_unsupported_dtypes(q_dtype, scaling_mode) key = jax.random.PRNGKey(0) @@ -780,22 +772,8 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt assert_dequantized_scaled_tensor(scaled_tensor, x) def _should_use_precise_comparison( - self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis ): - # TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values. - RHT_SLIGHT_MISMATCH_SHAPES = [ - ((32, 256, 128), -1), - ((64, 32, 32, 256), -1), - ((8192, 2, 4096), -2), - ] - - if ( - should_use_rht(scaling_mode, q_layout=q_layout) - and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES - ): - # TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes - return False - if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16: # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation return False @@ -805,7 +783,7 @@ def _should_use_precise_comparison( def test_quantize_bitwise( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): - self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + self._skip_unsupported_dtypes(q_dtype, scaling_mode) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) @@ -816,28 +794,20 @@ def test_quantize_bitwise( jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - try: - te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) - except AssertionError as e: - if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16: - error_message = e.args[0] - if "RHT requires input to be bfloat16" in error_message: - # Successfully caught the expected error, early return from the test - return - raise e + te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors( te_output, jax_output, precise_comparison=self._should_use_precise_comparison( - in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis ), ) def test_quantize_bitwise_jitted( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): - self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis) + self._skip_unsupported_dtypes(q_dtype, scaling_mode) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) @@ -851,21 +821,13 @@ def test_quantize_bitwise_jitted( jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) - try: - te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis) - except AssertionError as e: - if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16: - error_message = e.args[0] - if "RHT requires input to be bfloat16" in error_message: - # Successfully caught the expected error, early return from the test - return - raise e + te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis) assert_bitwise_scaled_tensors( te_output, jax_output, precise_comparison=self._should_use_precise_comparison( - in_dtype, scaling_mode, q_layout, input_shape, flatten_axis + in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis ), ) @@ -985,12 +947,6 @@ def _test_sr( def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other.""" - # HACK: FIXME TODO(jberchtold) - row = reduce(operator.mul, input_shape[flatten_axis:], 1) - col = reduce(operator.mul, input_shape[:flatten_axis], 1) - will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout) - if will_use_rht and (row % 64 != 0 or col % 128 != 0): - pytest.skip("Unfused RHT is not supported currently, skipping") key = jax.random.PRNGKey(0) inputs = jax.random.uniform(key, input_shape, in_dtype) @@ -1007,6 +963,97 @@ def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4) +@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn]) +@pytest_parametrize_wrapper( + "scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING] +) +class TestRandomizedHadamardTransform: + + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE] + ) + @pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)]) + def test_rht_quantize_bitwise_jitted( + self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ): + key = jax.random.PRNGKey(0) + inputs = jax.random.uniform(key, input_shape, in_dtype) + + te_quantizer, jax_quantizer = QuantizerFactory.create( + n_quantizers=2, + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + use_rht=True, + ) + + jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3)) + te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,)) + + jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis) + + te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis) + + assert_bitwise_scaled_tensors(te_output, jax_output) + + def _ref_gemm_with_jnp_dot(self, a, b, data_layout): + if data_layout[0] == "T": + a = jnp.swapaxes(a, -1, -2) + if data_layout[1] == "T": + b = jnp.swapaxes(b, -1, -2) + return jnp.dot(a, b) + + def _generate_gemm_input(self, m, n, k, data_layout): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + x = jax.random.uniform( + subkeys[0], + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), + dtype=jnp.bfloat16, + ) / jnp.sqrt(k) + w = jax.random.uniform( + subkeys[1], + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), + dtype=jnp.bfloat16, + ) / jnp.sqrt(n) + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + + return (x, w, contracting_dims) + + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + # We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently + @pytest_parametrize_wrapper("data_layout", ["TN", "NT"]) + @pytest_parametrize_wrapper("with_jax_gemm", [True, False]) + def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm): + key = jax.random.PRNGKey(0) + + lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + lhs_quantizer = QuantizerFactory.create( + scaling_mode=lhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + use_rht=True, + ) + rhs_quantizer = QuantizerFactory.create( + scaling_mode=rhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + use_rht=True, + ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + ) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn) + + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index ca804625c..fc88b7ef7 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -3,11 +3,13 @@ # See LICENSE for license information. import unittest +from functools import partial import flax import jax import jax.numpy as jnp import numpy as np +from flax import linen as nn from utils import assert_allclose from transformer_engine.common.recipe import ( @@ -24,15 +26,51 @@ ScalingMode, update_collections, TensorSource, + QuantizerFactory, + QuantizeLayout, ) from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.sharding import MeshResource, global_mesh_resource +from transformer_engine.jax.flax.module import TransformerEngineBase is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) +def quantizer_check_vjp(outer_quantizer_set, assertion_func, x): + """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries.""" + + # Define a function with a custom VJP (vector-Jacobian product) + @partial(jax.custom_vjp, nondiff_argnums=(1,)) + def quantizer_check(inner_quantizer_set, assertion_func, x): + return quantizer_check_fwd(inner_quantizer_set, assertion_func, x) + + def quantizer_check_fwd(inner_quantizer_set, assertion_func, x): + assertion_func(inner_quantizer_set.x, TensorSource.X) + assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL) + assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD) + return x + + def quantizer_check_bwd(ctx, g): + return (g,) + + quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd) + return quantizer_check(outer_quantizer_set, assertion_func, x) + + +class TestModule(TransformerEngineBase): + """A simple module to test quantizer creation and reconstruction across VJP boundaries.""" + + # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None + assertion_func: callable + + @nn.compact + def __call__(self, x): + quantizer_set = self.generate_quantizer_set() + return quantizer_check_vjp(quantizer_set, self.assertion_func, x) + + class TestHelper(unittest.TestCase): @unittest.skipIf(not is_fp8_supported, reason=reason) @@ -89,12 +127,43 @@ def _compare_nvfp4_scaling(self, test): for tensor_source in TensorSource: target_scaling_mode = ( ScalingMode.NVFP4_2D_SCALING - if tensor_source == TensorSource.KERNEL + if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL else ScalingMode.NVFP4_1D_SCALING ) self.assertEqual( get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode ) + self.assertEqual( + get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding + ) + self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht) + self.assertEqual( + get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization + ) + + def _compare_nvfp4_scaling_quantizers(self, test): + """Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries.""" + + def assertion_func(quantizer, tensor_source): + if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD: + self.assertIsNone(quantizer.stochastic_rounding_rng_state) + else: + self.assertIsNotNone(quantizer.stochastic_rounding_rng_state) + + expected_rht = ( + quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE} + and not test.disable_rht + ) + self.assertEqual(quantizer.use_rht, expected_rht) + + x = jnp.ones((), dtype=jnp.float32) + test_module = TestModule(assertion_func=assertion_func) + param_key, sr_key = jax.random.split(jax.random.PRNGKey(0)) + rngs = {"params": param_key, "sr_rng": sr_key} + variables = test_module.init(rngs, x) + + jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs) @unittest.skipIf(not is_fp8_supported, reason=reason) def test_autocast_delayed_scaling(self): @@ -171,5 +240,16 @@ def test_autocast_nvfp4_block_scaling(self): with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_nvfp4_scaling(bs) + self._compare_nvfp4_scaling_quantizers(bs) + + bs = NVFP4BlockScaling( + disable_stochastic_rounding=True, + disable_rht=True, + disable_2d_quantization=True, + ) + with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): + self.assertTrue(get_quantize_config().is_fp8_enabled()) + self._compare_nvfp4_scaling(bs) + self._compare_nvfp4_scaling_quantizers(bs) self._check_default_state() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index b37c4bd84..778f77c0d 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -44,7 +44,6 @@ noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, - should_use_rht, ) from .misc import get_padded_spec, is_all_reduce_in_float32 from ..sharding import ( @@ -169,16 +168,13 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x) - def uses_rht(q: AbstractBaseTensor) -> bool: - return isinstance(q, ScaledTensor1x) and should_use_rht( - q.scaling_mode, is_colwise=q.is_colwise - ) + def has_rht_applied(q: AbstractBaseTensor) -> bool: + return isinstance(q, ScaledTensor1x) and q.has_rht_applied - # TODO(jberchtold): Move RHT usage check to a bool flag on the ScaledTensor class - assert uses_rht(lhs_q) == uses_rht(rhs_q), ( - "With NVFP4_1D_SCALING, if one operand is colwise quantized, the other must be colwise" - " quantized as well. This is to ensure the RHT is applied to both and will cancel out in" - " the GEMM." + assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), ( + "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized" + " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the" + " GEMM." ) return lhs_q, rhs_q diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index b3f1e60f9..67c505bc9 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -31,7 +31,7 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, - num_of_devices, + get_num_devices_in_mesh, ) from ..quantize import ( ScaledTensor2x, @@ -45,7 +45,6 @@ compute_scale_from_amax, NoScaleTensor, get_rht_matrix, - should_use_rht, ) @@ -108,17 +107,18 @@ def abstract( "sr_rng_state must be a uint32 array when stochastic_rounding is True but" f" received {sr_rng_state_aval}" ) - if is_outer: + if is_outer and get_num_devices_in_mesh() > 1: assert ( - sr_rng_state_aval.shape[0] == num_of_devices() + sr_rng_state_aval.shape[0] == get_num_devices_in_mesh() and sr_rng_state_aval.shape[1] == 4 ), ( "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" f" True and is_outer is True but received {sr_rng_state_aval.shape}" ) else: - assert sr_rng_state_aval.shape == (4,), ( - "Sharded sr_rng_state must be of shape (4,) per device when" + # We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state. + assert sr_rng_state_aval.size >= 4, ( + "Sharded sr_rng_state must have at least 4 elements per device when" f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" ) @@ -552,8 +552,13 @@ def partition( desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) - # TODO(jberchtold): Assert the sr_rng state is sharded along all mesh axes - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + arg_shardings = list(arg_i.sharding for arg_i in arg_infos) + arg_shardings[3] = NamedSharding( + mesh, + PartitionSpec(tuple(x for x in x_spec if x is not None), None), + desc="BaseDBiasQuantizePrimitive.sr_rng_state", + ) + arg_shardings = tuple(arg_shardings) out_shardings = ( out_sharding, colwise_out_sharding, @@ -564,6 +569,9 @@ def partition( ) def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): + if sr_rng_state.size > 4: + # See comment in abstract method for explanation of why we cannot assert exact shape + sr_rng_state = sr_rng_state.flatten()[:4] ( local_x, local_colwise_x, @@ -754,9 +762,10 @@ def _quantize_dbias_impl( # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - is_unsupported = ( - quantizer.q_layout == QuantizeLayout.COLWISE - and quantizer.scaling_mode != ScalingMode.NVFP4_1D_SCALING + is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not ( + quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + and hasattr(quantizer, "use_rht") + and quantizer.use_rht ) if is_unsupported or not PrimitiveClass.enabled(): if is_dbias: @@ -792,7 +801,7 @@ def _quantize_dbias_impl( rht_matrix = jnp.empty((1, 1), jnp.bfloat16) amax = x.amax - if should_use_rht(quantizer.scaling_mode, q_layout=quantizer.q_layout): + if hasattr(quantizer, "use_rht") and quantizer.use_rht: use_rht = True rht_matrix = get_rht_matrix() @@ -861,7 +870,11 @@ def _quantize_dbias_impl( x.data, scale, amax, - sr_rng_state if sr_rng_state is not None else jnp.empty((num_of_devices(), 1), jnp.uint32), + ( + sr_rng_state + if sr_rng_state is not None + else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32) + ), post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), rht_matrix, out_dtype=quantizer.q_dtype, @@ -902,6 +915,7 @@ def _quantize_dbias_impl( q_layout=quantizer.q_layout, data_layout=quantizer.get_data_layout(), flatten_axis=flatten_axis, + colwise_has_rht_applied=use_rht, ) return out, dbias.astype(dq_dtype) diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index b4da6f3be..80ebc6b87 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -15,7 +15,7 @@ import jax.numpy as jnp from .scaling_modes import ScalingMode -from .hadamard import apply_rht, should_use_rht +from .hadamard import apply_rht __all__ = ["ScalingModeToDequantizerMap"] @@ -171,7 +171,9 @@ class NVFP4Dequantizer(Dequantizer): """ @staticmethod - def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis): + def _dequantize_func( + data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis, has_rht_applied + ): """Dequantize a tensor using block scaling. Args: @@ -182,6 +184,7 @@ def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, scaling_mode: The scaling mode used for quantization is_colwise: Whether the scaling is column-wise flatten_axis: The axis along which the tensor could be flattened to 2D + has_rht_applied: Whether the quantization has RHT applied and we need to apply the inverse RHT to dequantize Returns: The dequantized tensor @@ -223,8 +226,7 @@ def _dequantize_func(data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape) # Apply inverse of RHT if needed - use_rht = should_use_rht(scaling_mode, is_colwise=is_colwise) - if use_rht: + if has_rht_applied: out = apply_rht(out, inverse=True) return out @@ -247,6 +249,7 @@ def dequantize(scaled_tensor): scaled_tensor.scaling_mode, scaled_tensor.is_colwise, scaled_tensor.flatten_axis, + scaled_tensor.has_rht_applied, ) diff --git a/transformer_engine/jax/quantize/hadamard.py b/transformer_engine/jax/quantize/hadamard.py index c0b74ef75..5f6f0ec2b 100644 --- a/transformer_engine/jax/quantize/hadamard.py +++ b/transformer_engine/jax/quantize/hadamard.py @@ -4,32 +4,6 @@ """Randomized Hadamard Transform (RHT) utilities for JAX.""" import jax.numpy as jnp -from .scaling_modes import ScalingMode - - -def should_use_rht(scaling_mode, is_colwise=None, q_layout=None) -> bool: - """Determine if RHT (Randomized Hadamard Transform) should be used. - - Args: - scaling_mode: The scaling mode of the tensor. - is_colwise: Whether the tensor is column-wise. Only one of is_colwise or q_layout should be provided. - q_layout: The quantization layout of the tensor. Only one of is_colwise or q_layout should be provided. - - Returns: - bool: True if RHT should be used, False otherwise. - """ - # Delayed import to avoid circular dependencies - from .quantizer import QuantizeLayout - - assert (is_colwise is None) != ( - q_layout is None - ), "Exactly one of is_colwise or q_layout must be provided." - - if q_layout is not None: - is_colwise = q_layout in {QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE} - - return scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise - def get_wgrad_sign_vector() -> list[int]: """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization.""" diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 06c67b62e..e8b33c1d1 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -12,6 +12,7 @@ from contextlib import contextmanager from dataclasses import dataclass from enum import Enum +import hashlib from typing import Optional, Tuple, Dict, Union, Sequence, Type, List from functools import reduce, lru_cache import operator @@ -35,7 +36,7 @@ from transformer_engine.jax.sharding import ( global_shard_guard, MeshResource, - num_of_devices, + get_num_devices_in_mesh, get_all_mesh_axes, with_sharding_constraint, ) @@ -561,29 +562,87 @@ def get_quantize_flax_meta( return QuantizeMeta() +@dataclass class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for NVFP4 scaling recipe. This class provides specific initialization and finalization for NVFP4 scaling quantization mode. """ + DISABLE_STOCHASTIC_ROUNDING: bool = False + DISABLE_RHT: bool = False + DISABLE_2D_QUANTIZATION: bool = False + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: - """Initialize block scaling FP8 configuration. + """Initialize block scaling NVFP4 configuration. Args: - fp8_recipe: The FP8 recipe to use for initialization + fp8_recipe: The quantization recipe to use for initialization """ + assert isinstance(fp8_recipe, NVFP4BlockScaling) + self.INITIALIZED = True self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format) self.AMAX_HISTORY_LEN = 0 + self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding + self.DISABLE_RHT = fp8_recipe.disable_rht + self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" - if tensor_source == TensorSource.KERNEL: + if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL: return ScalingMode.NVFP4_2D_SCALING # for x and grad return ScalingMode.NVFP4_1D_SCALING + def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta: + """Create the quantization metadata for RHT if applicable.""" + # Imported here to prevent circular import + from transformer_engine.jax.quantize import QuantizeLayout + + use_rht = self.get_scaling_mode( + tensor_source + ) == ScalingMode.NVFP4_1D_SCALING and q_layout in { + QuantizeLayout.ROWWISE_COLWISE, + QuantizeLayout.COLWISE, + } + if self.DISABLE_RHT: + use_rht = False + return QuantizeMeta(use_rht=use_rht) + + def _make_stochastic_rounding_rng_state( + self, module, tensor_source: TensorSource, quantizer_name: str + ) -> jnp.ndarray: + """Create the stochastic rounding rng state if applicable.""" + if self.DISABLE_STOCHASTIC_ROUNDING: + return QuantizeMeta() + + if tensor_source != TensorSource.DGRAD: + # Only DGRAD uses stochastic rounding + return QuantizeMeta() + + sr_jax_rng = module.make_rng("sr_rng") + # Get a unique key for this quantizer + # Use hashlib to get a deterministic hash value for quantizer_name + quantizer_hash = ( + int(hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest(), 16) + % jnp.iinfo(jnp.int32).max + ) + sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash) + + # Generate 4 random uint32 values from the JAX PRNG key + shape = (4,) + if get_num_devices_in_mesh() > 1: + shape = (get_num_devices_in_mesh(), 4) + sr_jax_rng_state = jax.random.randint( + sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 + ).view(jnp.uint32) + sr_jax_rng_state = with_sharding_constraint( + sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None) + ) + return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state) + def get_quantize_flax_meta( self, module, @@ -603,27 +662,14 @@ def get_quantize_flax_meta( Returns: The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. """ - if tensor_source != TensorSource.DGRAD: - # Only DGRAD uses stochastic rounding - return QuantizeMeta() - - # TODO(jberchtold): This assumes SR is always enabled for NVFP4. Use flag from recipe to toggle it. - sr_jax_rng = module.make_rng("sr_rng") - # Get a unique key for this quantizer - sr_jax_rng = jax.jit(jax.random.fold_in)( - sr_jax_rng, hash(quantizer_name) % jnp.iinfo(jnp.int32).max - ) + # Imported here to prevent circular import + from transformer_engine.jax.quantize import QuantizeLayout - # Generate 4 random uint32 values from the JAX PRNG key - sr_jax_rng_state = jax.random.randint( - sr_jax_rng, (num_of_devices(), 4), 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 - ).view(jnp.uint32) - sr_jax_rng_state = with_sharding_constraint( - sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None) + return QuantizeMeta.merge( + self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source), + self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name), ) - return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state) - _QUANTIZE_CONFIG = NoOpQuantizeConfig() diff --git a/transformer_engine/jax/quantize/metadata.py b/transformer_engine/jax/quantize/metadata.py index 11a349ed7..a987643eb 100644 --- a/transformer_engine/jax/quantize/metadata.py +++ b/transformer_engine/jax/quantize/metadata.py @@ -26,6 +26,26 @@ class QuantizeMeta: """ + @staticmethod + def merge(a: "QuantizeMeta", b: "QuantizeMeta") -> "QuantizeMeta": + """Merge two QuantizeMeta instances. + + Args: + a (QuantizeMeta): The first QuantizeMeta instance. + b (QuantizeMeta): The second QuantizeMeta instance. + + Returns: + QuantizeMeta: A new QuantizeMeta instance with merged metadata. + """ + assert isinstance(a, QuantizeMeta) + assert isinstance(b, QuantizeMeta) + for key in b.get_kwargs_dictionary().keys(): + if key in a.get_kwargs_dictionary(): + assert ( + a.get_kwargs_dictionary()[key] == b.get_kwargs_dictionary()[key] + ), f"Conflict in merging QuantizeMeta: {key} has different values." + return QuantizeMeta(**{**a.get_kwargs_dictionary(), **b.get_kwargs_dictionary()}) + def __init__(self, **kwargs): self._kwargs = kwargs diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 7bc08f834..d138b58da 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -19,7 +19,7 @@ from transformer_engine.common import recipe from .scaling_modes import ScalingMode -from .hadamard import apply_rht, should_use_rht +from .hadamard import apply_rht from .tensor import ( ScaledTensor, ScaledTensor1x, @@ -590,11 +590,13 @@ class NVFP4Quantizer(Quantizer): q_layout: Quantization axis data_layout: Data layout string (default: "NT") stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled. + use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization. """ scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE data_layout: str = "NT" + use_rht: bool = False stochastic_rounding_rng_state: Optional[jnp.ndarray] = None def __post_init__(self): @@ -603,6 +605,30 @@ def __post_init__(self): ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn" assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes" + def tree_flatten(self): + """Flatten the quantizer for JAX tree operations. + + Returns: + Tuple of (children, aux_data) for tree operations + """ + children = (self.stochastic_rounding_rng_state,) + aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.use_rht) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstruct a quantizer from its flattened representation. + + Args: + aux_data: Auxiliary data containing quantizer parameters + children: Unused children data + + Returns: + A reconstructed Quantizer instance + """ + stochastic_rounding_rng_state = children[0] + return cls(*aux_data, stochastic_rounding_rng_state=stochastic_rounding_rng_state) + def _apply_stochastic_rounding(self, x): assert ( self.stochastic_rounding_rng_state is not None @@ -688,8 +714,9 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> flatten_axis = x.ndim - flatten_axis x_shape = x.shape - if should_use_rht(self.scaling_mode, is_colwise=is_colwise): - # We only apply RHT for 1D colwise nvfp4 + # We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now. + use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING + if use_rht: x = apply_rht(x) dq_dtype = dq_dtype if dq_dtype is not None else x.dtype @@ -790,6 +817,7 @@ def repeat_to_shape(x, target_shape): scaling_mode=self.scaling_mode, dq_dtype=dq_dtype, flatten_axis=rowwise_flatten_axis, + has_rht_applied=use_rht, ) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 2d2d78190..6c358a044 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -175,6 +175,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): is_colwise: Whether the tensor uses column-wise quantization data_layout: The data_layout specification for the tensor flatten_axis: The quantization axis for the tensor + has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization """ scale_inv: jnp.ndarray @@ -184,6 +185,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): is_colwise: bool data_layout: str flatten_axis: int + has_rht_applied: bool def __post_init__(self): """Validates and adjusts the scale_inv shape after initialization. @@ -243,6 +245,7 @@ def tree_flatten(self): self.is_colwise, self.data_layout, self.flatten_axis, + self.has_rht_applied, ) return (children, aux_data) @@ -314,6 +317,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st is_colwise=self.is_colwise, data_layout=self.data_layout, flatten_axis=self.flatten_axis, + has_rht_applied=self.has_rht_applied, ) @@ -354,6 +358,7 @@ def __init__( self.group_sizes = group_sizes self.original_shape = original_shape self.group_axis = group_axis + # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, scale_inv=scale_inv, @@ -364,6 +369,7 @@ def __init__( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, + has_rht_applied=False, ) def __post_init__(self): @@ -515,6 +521,7 @@ def create_1x( group_sizes=None, original_shape=None, group_axis=0, + has_rht_applied=False, ): """Creates a single-scale quantized tensor. @@ -530,6 +537,7 @@ def create_1x( group_sizes: Array of ints containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) + has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided @@ -593,6 +601,7 @@ def create_1x( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, + has_rht_applied=has_rht_applied, ) @staticmethod @@ -610,6 +619,8 @@ def create_2x( group_sizes=None, original_shape=None, group_axis=0, + rowwise_has_rht_applied=False, + colwise_has_rht_applied=False, ): """Creates a double-scale quantized tensor. @@ -626,6 +637,8 @@ def create_2x( group_sizes: Array containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) + rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) Returns: A ScaledTensor2x instance @@ -648,6 +661,7 @@ def create_2x( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=rowwise_has_rht_applied, ) colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, @@ -661,6 +675,7 @@ def create_2x( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=colwise_has_rht_applied, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -680,6 +695,8 @@ def create( group_sizes: jnp.ndarray = None, original_shape: Tuple[int] = None, group_axis: int = 0, + rowwise_has_rht_applied: bool = False, + colwise_has_rht_applied: bool = False, ): """Creates a scaled tensor based on the quantization axis. @@ -696,10 +713,14 @@ def create( group_sizes: Array containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) + rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) Returns: Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout """ + assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet" + if q_layout == QuantizeLayout.ROWWISE_COLWISE: return ScaledTensorFactory.create_2x( data, @@ -715,6 +736,8 @@ def create( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + rowwise_has_rht_applied=rowwise_has_rht_applied, + colwise_has_rht_applied=colwise_has_rht_applied, ) is_colwise = q_layout == QuantizeLayout.COLWISE @@ -731,6 +754,7 @@ def create( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=colwise_has_rht_applied, ) return ScaledTensorFactory.create_1x( @@ -745,6 +769,7 @@ def create( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=rowwise_has_rht_applied, ) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 8eeaca4cc..adb67e358 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -238,6 +238,19 @@ def num_of_devices(): return len(jax.devices()) +def get_num_devices_in_mesh(mesh=None): + """ + Get the number of devices in the given mesh. + If the mesh is None, it would be replaced + by the global mesh. + """ + if mesh is None: + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + if mesh.empty: + return 1 + return np.prod(list(mesh.shape.values())) + + def get_mesh_axis_size(axis, mesh=None): """ Get the axis size of the given mesh. From 2ac3c16876fe3bcd4866f8a62251802ea5530888 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 22 Oct 2025 13:11:31 -0700 Subject: [PATCH 077/141] [JAX] Defer TE/JAX cublas shape check on fp8 gemms until lowering (#2292) Defer cublas check on fp8 gemms until lowering Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 778f77c0d..72bee251c 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -470,29 +470,6 @@ def _dims_are_consecutive(dims): f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" ) - lhs_axis_boundary = get_lhs_axis_boundary(lhs_contracting_dims, lhs_is_transposed) - lhs_contracting_size = ( - reduce(operator.mul, lhs.shape[lhs_axis_boundary:]) - if lhs_is_transposed - else reduce(operator.mul, lhs.shape[:lhs_axis_boundary]) - ) - assert_cublas_requirements( - scaling_mode, - lhs_contracting_size, - "LHS", - ) - rhs_axis_boundary = get_rhs_axis_boundary(rhs_contracting_dims, rhs_is_transposed) - rhs_contracting_size = ( - reduce(operator.mul, rhs.shape[:rhs_axis_boundary]) - if rhs_is_transposed - else reduce(operator.mul, rhs.shape[rhs_axis_boundary:]) - ) - assert_cublas_requirements( - scaling_mode, - rhs_contracting_size, - "RHS", - ) - # Determine output shape and dtype assert ( dtypes.canonicalize_dtype(out_dtype).itemsize > 1 @@ -601,6 +578,29 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) + lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) + lhs_contracting_size = ( + reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + if lhs_transposed + else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + ) + assert_cublas_requirements( + scaling_mode, + lhs_contracting_size, + "LHS", + ) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + rhs_contracting_size = ( + reduce(operator.mul, rhs_aval.shape[:rhs_axis_boundary]) + if rhs_transposed + else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + ) + assert_cublas_requirements( + scaling_mode, + rhs_contracting_size, + "RHS", + ) + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { "scaling_mode": int(scaling_mode.value), From 66acb8e97baa095c8a6e0001bc27aca4f6a8574e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 22 Oct 2025 20:33:49 -0400 Subject: [PATCH 078/141] Include TE core headers in final build (#2291) Include TE core headers in build Signed-off-by: Kirthi Shankar Sivamani --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) create mode 100644 MANIFEST.in diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 000000000..c34025772 --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1 @@ +recursive-include transformer_engine/common/include *.* From eb34783cb774438a4367e45d478744d2799e1a7f Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Wed, 22 Oct 2025 22:31:08 -0700 Subject: [PATCH 079/141] Overhaul the compilation for the arch-specific features (#2279) * Added sm_120f to the build Signed-off-by: Przemek Tredak * Change the arch specific handling Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * Support for CUDA<12.9 Signed-off-by: Przemek Tredak * Moved through the rest of the files Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * Common cases Signed-off-by: Przemek Tredak * Remove pure 100 from the list Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * CMake changes, (not yet working) Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * Do not pass the arch-specific thing from build_tools Signed-off-by: Przemek Tredak * Fix Signed-off-by: Przemek Tredak * Moved some of the files to arch-specific compilation Signed-off-by: Przemek Tredak * Fix and also changing the order of compilation to hopefully get the compilation time lower Signed-off-by: Przemek Tredak * Fix for the files overwriting custom compile properties Signed-off-by: Przemek Tredak * Actually make this whole thing work Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add space to the error message Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak * Apply suggestions from code review Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Signed-off-by: Przemyslaw Tredak * Fixes from review Signed-off-by: Przemek Tredak * Changing the naming to be more intuitive Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add missing cassert include for device-side asserts Signed-off-by: Przemek Tredak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Przemek Tredak Signed-off-by: Przemyslaw Tredak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> --- build_tools/utils.py | 6 +- transformer_engine/common/CMakeLists.txt | 206 +++++++++--- .../hadamard_transform_cast_fusion.cu | 27 +- ...quantize_transpose_vector_blockwise_fp4.cu | 76 ++--- .../common/util/nvfp4_transpose.cuh | 290 ++++++++-------- transformer_engine/common/util/ptx.cuh | 310 +++++++++++++++--- transformer_engine/common/utils.cuh | 1 + 7 files changed, 610 insertions(+), 306 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index 296f928b7..395b41261 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -257,11 +257,9 @@ def cuda_archs() -> str: if archs is None: version = cuda_version() if version >= (13, 0): - archs = "75;80;89;90;100;100a;103a;120" - elif version >= (12, 9): - archs = "70;80;89;90;100;100a;103a;120" + archs = "75;80;89;90;100;120" elif version >= (12, 8): - archs = "70;80;89;90;100;100a;120" + archs = "70;80;89;90;100;120" else: archs = "70;80;89;90" return archs diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e6be47686..175abd353 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -5,15 +5,6 @@ cmake_minimum_required(VERSION 3.21) # Language options -if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) - elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) - else () - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) - endif() -endif() set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) @@ -30,8 +21,62 @@ project(transformer_engine LANGUAGES CUDA CXX) # CUDA Toolkit find_package(CUDAToolkit REQUIRED) -if (CUDAToolkit_VERSION VERSION_LESS 12.0) - message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") +if (CUDAToolkit_VERSION VERSION_LESS 12.1) + message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}") +endif() + +# Process GPU architectures +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + endif() +endif() + +# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures +set(NVTE_GENERIC_ARCHS) +set(NVTE_SPECIFIC_ARCHS) + +# Check for architecture 100 +list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index) +if(NOT arch_100_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100") + list(APPEND NVTE_GENERIC_ARCHS "100") + list(APPEND NVTE_SPECIFIC_ARCHS "100a") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND NVTE_SPECIFIC_ARCHS "103a") + endif() +endif() + +# Check for architecture 101 (if we see this we are in toolkit <= 12.9) +list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index) +if(NOT arch_101_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101") + list(APPEND NVTE_GENERIC_ARCHS "101") + list(APPEND NVTE_SPECIFIC_ARCHS "101a") +endif() + +# Check for architecture 110 (if we see this we are in toolkit >= 13.0) +list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index) +if(NOT arch_110_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110") + list(APPEND NVTE_GENERIC_ARCHS "110") + list(APPEND NVTE_SPECIFIC_ARCHS "110f") +endif() + +# Check for architecture 120 +list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index) +if(NOT arch_120_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120") + list(APPEND NVTE_GENERIC_ARCHS "120") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND NVTE_SPECIFIC_ARCHS "120f") + else() + list(APPEND NVTE_SPECIFIC_ARCHS "120a") + endif() endif() # cuDNN frontend API @@ -78,9 +123,28 @@ endif() # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) -list(APPEND transformer_engine_SOURCES +set(transformer_engine_cpp_sources) +set(transformer_engine_cuda_sources) +set(transformer_engine_cuda_arch_specific_sources) + +list(APPEND transformer_engine_cpp_sources cudnn_utils.cpp transformer_engine.cpp + fused_attn/fused_attn.cpp + gemm/config.cpp + normalization/common.cpp + normalization/layernorm/ln_api.cpp + normalization/rmsnorm/rmsnorm_api.cpp + util/cuda_driver.cpp + util/cuda_nvml.cpp + util/cuda_runtime.cpp + util/multi_stream.cpp + util/rtc.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/comm_gemm_overlap.cpp) + +list(APPEND transformer_engine_cuda_sources common.cu multi_tensor/adam.cu multi_tensor/compute_scale.cu @@ -92,40 +156,23 @@ list(APPEND transformer_engine_SOURCES transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu - transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu - activation/gelu.cu dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu - activation/relu.cu - activation/swiglu.cu fused_attn/fused_attn_fp8.cu - fused_attn/fused_attn.cpp fused_attn/utils.cu - gemm/config.cpp gemm/cublaslt_gemm.cu - gemm/cutlass_grouped_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_api.cpp normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu - util/cast.cu util/padding.cu - util/cuda_driver.cpp - util/cuda_nvml.cpp - util/cuda_runtime.cpp - util/multi_stream.cpp - util/rtc.cpp swizzle/swizzle.cu swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu @@ -139,12 +186,58 @@ list(APPEND transformer_engine_SOURCES recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu recipe/nvfp4.cu + comm_gemm_overlap/userbuffers/userbuffers.cu) + +list(APPEND transformer_engine_cuda_arch_specific_sources + gemm/cutlass_grouped_gemm.cu + util/cast.cu + activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu hadamard_transform/hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) + hadamard_transform/hadamard_transform_cast_fusion.cu) + +# Compiling the files with the worst compilation time first to hopefully overlap +# better with the faster-compiling cpp files +list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_sources} + ${transformer_engine_cuda_sources} + ${transformer_engine_cpp_sources}) + +# Set compile options for CUDA sources with generic architectures +foreach(cuda_source IN LISTS transformer_engine_cuda_sources) + set(arch_compile_options) + foreach(arch IN LISTS NVTE_GENERIC_ARCHS) + list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}") + endforeach() + + if(arch_compile_options) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_OPTIONS ${arch_compile_options} + ) + endif() +endforeach() + +# Set compile options for CUDA sources with specific architectures +foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) + set(arch_compile_options) + foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS) + list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}") + endforeach() + + if(arch_compile_options) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_OPTIONS ${arch_compile_options} + ) + endif() +endforeach() if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES @@ -249,28 +342,35 @@ target_include_directories(transformer_engine PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/string_headers") # Compiler options -set_source_files_properties(fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") +set(nvte_sources_with_fast_math) +list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu + multi_tensor/adam.cu + multi_tensor/compute_scale.cu + multi_tensor/l2norm.cu + multi_tensor/scale.cu + multi_tensor/sgd.cu + fused_attn/flash_attn.cu + fused_attn/context_parallel.cu + fused_attn/kv_cache.cu) + option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) - set_source_files_properties(activation/gelu.cu - activation/relu.cu - activation/swiglu.cu - util/cast.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") + list(APPEND nvte_sources_with_fast_math activation/gelu.cu + activation/relu.cu + activation/swiglu.cu + util/cast.cu) endif() + +foreach(cuda_source IN LISTS nvte_sources_with_fast_math) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_OPTIONS "--use_fast_math") +endforeach() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index ce191b5ff..263a32623 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -97,22 +97,23 @@ cutlass::Array StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { using result_type = cutlass::Array; result_type output; -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - auto output_ptr = reinterpret_cast(&output); - asm volatile( \ - "{\n" \ - "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ - "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ - "}" \ - : "=h"(output_ptr[0]), + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + auto output_ptr = reinterpret_cast(&output); + asm volatile( \ + "{\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ + "}" \ + : "=h"(output_ptr[0]), "=h"(output_ptr[1]) - : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1])); -#else - NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + } else { + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } return output; } diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index eced2c4bb..fed18c51f 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -264,48 +264,50 @@ __device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, s __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( const float2 in01, const float2 in23, const uint32_t rbits) { -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - uint16_t out_4x; - asm volatile( - "{\n" - "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" - "}" - : "=h"(out_4x) - : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); - return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - uint16_t dummy = 0; - return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + uint16_t out_4x; + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + "}" + : "=h"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt.rs PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); + } } __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, const float2 in23, const uint32_t rbits) { -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - // NOTE: rbits unused for rn. - uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. - asm volatile( - "{\n" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); - return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - uint16_t dummy = 0; - return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY; + if constexpr (has_fp4) { + // NOTE: rbits unused for rn. + uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + asm volatile( + "{\n" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); + } } template diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh index 712b557c5..45fa29f0e 100644 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -15,10 +15,9 @@ #include #include -#if CUDA_VERSION > 12080 +#if FP4_TYPE_SUPPORTED #include -#endif // CUDA_VERSION > 12080 - +#endif // FP4_TYPE_SUPPORTED #include #include "../common.h" @@ -30,7 +29,7 @@ namespace transformer_engine { -#if CUDA_VERSION > 12080 +#if FP4_TYPE_SUPPORTED namespace nvfp4_transpose { using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + @@ -152,89 +151,89 @@ __device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int return rbits; } -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( const uint64_t in_4x, const float2 scale, const uint32_t rbits) { uint16_t out_4x = 0; -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } return *reinterpret_cast(&out_4x); } __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, const float2 scale, const uint32_t rbits) { - // NOTE: rbits unused for rn. + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale))); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } return reinterpret_cast(&out_4x)[0]; } @@ -252,34 +251,35 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { uint16_t out_4x = 0; -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale)), "r"(rbits)); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } return *reinterpret_cast(&out_4x); } @@ -287,40 +287,41 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 const float2 in23, const float2 scale, const uint32_t rbits) { - // NOTE: rbits unused for rn. + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale))); -#else - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); -#endif // CUDA_ARCH_HAS_FEATURE_SM10X_ALL + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } return reinterpret_cast(&out_4x)[0]; } @@ -335,8 +336,6 @@ __device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, c } } -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - template __global__ void __launch_bounds__(THREADS_NUM) @@ -1380,18 +1379,13 @@ __global__ void __launch_bounds__(THREADS_NUM) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } } // namespace nvfp4_transpose -#endif // CUDA_VERSION > 12080 - -// Compile-time flag to choose kernel variant -#ifndef USE_2D_NVFP4_KERNEL -#define USE_2D_NVFP4_KERNEL 0 -#endif +#endif // FP4_TYPE_SUPPORTED template void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, const QuantizationConfig *quant_config, cudaStream_t stream) { -#if CUDA_VERSION > 12080 +#if FP4_TYPE_SUPPORTED bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to @@ -1509,7 +1503,7 @@ void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *o });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif // CUDA_VERSION > 12080 +#endif // FP4_TYPE_SUPPORTED } } // namespace transformer_engine diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 85717afdf..aeac2b4a2 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -18,44 +18,165 @@ #include #endif // CUDA_VERSION >= 12080 +#include "common/utils.cuh" + namespace transformer_engine { + namespace ptx { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +template +struct ArchSpecific { + constexpr static int id = N * 10; + + template + constexpr static bool compatible() { + if constexpr (CurrentArch == id) { + static_assert(ArchSpecific == CurrentArch, + "Compiled for the generic architecture, while utilizing arch-specific " + "features. Please compile for smXXXa architecture instead of smXXX " + "architecture."); + return true; + } else { + return false; + } + } +}; + +template +struct FamilySpecific { + constexpr static int id = N * 10; + + template + constexpr static bool compatible() { + if constexpr ((CurrentArch / 100) == (id / 100)) { + static_assert(FamilySpecific == CurrentArch, + "Compiled for the generic architecture, while utilizing family-specific " + "features. Please compile for smXXXf architecture instead of smXXX " + "architecture."); + return true; + } else { + return false; + } + } +}; + +template +constexpr bool is_supported_arch() { + if constexpr (T::template compatible()) { + return true; + } else if constexpr (sizeof...(U) != 0) { + return is_supported_arch(); + } else { + return false; + } +} + +#if CUDA_VERSION < 12090 +#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL) +#define __CUDA_ARCH_SPECIFIC__ 900 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 900 +#endif +#if __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) +#define __CUDA_ARCH_SPECIFIC__ 1000 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1000 +#endif +#if __CUDA_ARCH_HAS_FEATURE__(SM101_ALL) +#define __CUDA_ARCH_SPECIFIC__ 1010 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1010 +#endif +#if __CUDA_ARCH_HAS_FEATURE__(SM120_ALL) +#define __CUDA_ARCH_SPECIFIC__ 1200 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1200 +#endif +#endif + +#ifdef __CUDA_ARCH__ +#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__; +#else +#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = 0; +#endif + +#ifdef __CUDA_ARCH_SPECIFIC__ +#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = __CUDA_ARCH_SPECIFIC__; +#else +#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = 0; +#endif + +#ifdef __CUDA_ARCH_FAMILY_SPECIFIC__ +#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = __CUDA_ARCH_FAMILY_SPECIFIC__; +#else +#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = 0; +#endif + +#define NVTE_CUDA_ARCH_MATCHES(...) \ + [&] { \ + __NVTE_CURRENT_ARCH__ \ + __NVTE_ARCH_SPECIFIC__ \ + __NVTE_ARCH_FAMILY_SPECIFIC__ \ + return transformer_engine::ptx::is_supported_arch(); \ + }(); + +#define ARCH_BLACKWELL_FAMILY \ + NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>, ptx::FamilySpecific<110>, \ + ptx::FamilySpecific<120>) +#define ARCH_HAS_STOCHASTIC_ROUNDING \ + NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval __device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive __device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_arrive is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void fence_mbarrier_init_release_cluster() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile("fence.mbarrier_init.release.cluster;"); +#else + NVTE_DEVICE_ERROR("fence_mbarrier_init_release_cluster is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // global -> shared::cluster __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); // triggers async copy, i.e. the thread continues until wait() on mbarrier @@ -67,6 +188,9 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr), "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_global_to_shared is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -74,6 +198,7 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); // triggers async copy, i.e. the thread continues until wait() on mbarrier @@ -85,9 +210,13 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_global_to_shared is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t waitComplete; asm volatile( "{\n\t .reg .pred P_OUT; \n\t" @@ -98,15 +227,21 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons : "r"(mbar_ptr), "r"(parity) : "memory"); return static_cast(waitComplete); +#else + NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + return true; } __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { } -} - +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; @@ -121,55 +256,53 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { return __int_as_float(biased_exp << FP32_MANTISSA_BITS); } -#define CUDA_ARCH_HAS_FEATURE_SM10X_ALL \ - ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM103_ALL))) - __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { -#if CUDA_ARCH_HAS_FEATURE_SM10X_ALL - - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - : "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; - } - if (isinf(val)) { - return 0xFE; - } - if (val == 0.0f) { - return 0x00; - } - uint32_t val_u32 = *reinterpret_cast(&val); - e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); - uint32_t mantissa = val_u32 & 0x7FFFFF; - // Round up exponent and deal with satfinite. - if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); + } else { + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. + if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; } - return exponent; -#endif } -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, const uint64_t *src_shmem, const uint32_t size) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr), "r"(src_shmem_ptr), "r"(size) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_shared_to_global is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -177,51 +310,93 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, uint64_t *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"( tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group __device__ __forceinline__ void cp_async_bulk_wait_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group template __device__ __forceinline__ void cp_async_bulk_wait_group_read() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 1;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 2;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 4;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group __device__ __forceinline__ void cp_async_bulk_commit_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.commit_group;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // Proxy fence (bi-directional): -__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } +__device__ __forceinline__ void fence_proxy_async() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.async;"); +#else + NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} __device__ __forceinline__ void fence_proxy_async_shared_cta() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("fence.proxy.async.shared::cta;"); +#else + NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template @@ -282,15 +457,6 @@ static_assert(sizeof(fp4e2m1x2) == 1); static_assert(sizeof(fp4e2m1x4) == 2); #endif // CUDA_VERSION >= 12080 -// cvt.rn.satfinite.e2m1x2.f32 d, a, b; // Convert two FP32 values to two packed e2m1 - -// cvt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 introduced in PTX ISA version 8.6. - -// vt.rn.satfinite{.relu}.{e2m1x2/e2m3x2/e3m2x2/ue8m0x2}.f32 is supported on following architectures: -// sm_100a -// sm_101a -// sm_120a - // When converting to .e2m1x2 data formats, the destination operand d has .b8 type. // When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, // and the converted values are packed in the destination operand d such that the value @@ -313,6 +479,7 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair; \n\t" @@ -325,10 +492,14 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, : "=h"(reinterpret_cast(out)) : "l"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair; \n\t" @@ -341,9 +512,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, : "=h"(reinterpret_cast(out)) : "l"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -363,9 +538,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -385,9 +564,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -407,9 +590,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -429,24 +616,33 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;" : "=r"(reinterpret_cast(dst)) : "r"(reinterpret_cast(p1)), "r"(reinterpret_cast(p2))); +#else + NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(reinterpret_cast(dst)) : "r"(reinterpret_cast(p1)), "r"(reinterpret_cast(p2))); +#else + NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - } // namespace ptx namespace { @@ -464,6 +660,8 @@ __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool i } // Syncthreads so initialized barrier is visible to all threads. __syncthreads(); +#else + NVTE_DEVICE_ERROR("initialize_barriers is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -479,6 +677,8 @@ __forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_m ptx::mbarrier_invalid(&mbar[iter]); } } +#else + NVTE_DEVICE_ERROR("destroy_barriers is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -498,6 +698,8 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_1d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -517,6 +719,8 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -543,6 +747,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_2d_to_sharedx2 is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -572,6 +778,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx3( // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_2d_to_sharedx3 is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index bc764ac74..2d37e9c85 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -16,6 +16,7 @@ #endif #if !defined(__CUDACC_RTC__) +#include #include #else // Importing C++ standard headers is a pain with NVRTC From e2f2a0b4ef206af541c903262476db8cbfab3fb8 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 23 Oct 2025 12:34:50 -0700 Subject: [PATCH 080/141] [JAX] Make SR rng state always 2D (num_devices, 4) to fix partitioning issue (#2294) * Make SR rng state always 2D (num_devices, 4) Signed-off-by: Jeremy Berchtold * fix pure-jax impl Signed-off-by: Jeremy Berchtold * fix test shape Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 2 +- transformer_engine/jax/quantize/helper.py | 6 ++---- transformer_engine/jax/quantize/quantizer.py | 18 +++++++++++------- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1217ebf65..11ff9d061 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -876,7 +876,7 @@ def _sample_sr_qdq( for i in range(num_samples): iter_key = jax.random.fold_in(key, i) sr_rng_state = jax.random.randint( - iter_key, (4,), minval=0, maxval=2**30 - 1, dtype=jnp.uint32 + iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32 ) quantizer = QuantizerFactory.create( q_dtype=q_dtype, diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index e8b33c1d1..d5093e70e 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -631,10 +631,8 @@ def _make_stochastic_rounding_rng_state( ) sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash) - # Generate 4 random uint32 values from the JAX PRNG key - shape = (4,) - if get_num_devices_in_mesh() > 1: - shape = (get_num_devices_in_mesh(), 4) + # Generate 4 random uint32 values per device from the JAX PRNG key + shape = (get_num_devices_in_mesh(), 4) sr_jax_rng_state = jax.random.randint( sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 ).view(jnp.uint32) diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index d138b58da..eb2b7b592 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -34,6 +34,7 @@ TensorSource, ) from .device_utils import is_fp8_gemm_with_all_layouts_supported +from ..sharding import get_num_devices_in_mesh __all__ = [ "QuantizeLayout", @@ -633,9 +634,11 @@ def _apply_stochastic_rounding(self, x): assert ( self.stochastic_rounding_rng_state is not None ), "Stochastic rounding RNG state is not initialized" - assert self.stochastic_rounding_rng_state.shape == ( - 4, - ), "Stochastic rounding RNG state must be of shape (4,)" + expected_sr_rng_state_shape = (get_num_devices_in_mesh(), 4) + assert self.stochastic_rounding_rng_state.shape == expected_sr_rng_state_shape, ( + "Stochastic rounding RNG state must be of shape (num_devices_in_mesh, 4). Expected" + f" {expected_sr_rng_state_shape}, but got {self.stochastic_rounding_rng_state.shape}" + ) assert ( self.stochastic_rounding_rng_state.dtype == jnp.uint32 ), "Stochastic rounding RNG state must be of dtype uint32" @@ -643,14 +646,15 @@ def _apply_stochastic_rounding(self, x): # Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s key_bits = jnp.array( [ - self.stochastic_rounding_rng_state[0], - self.stochastic_rounding_rng_state[1], + # only take the first device's RNG state as the pure-JAX stochastic rounding impl only uses a single-device + self.stochastic_rounding_rng_state[0][0], + self.stochastic_rounding_rng_state[0][1], ], dtype=jnp.uint32, ) key = jax.random.wrap_key_data(key_bits) - key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[2]) - key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[3]) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][2]) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][3]) abs_x = jnp.abs(x) sign_x = jnp.sign(x) From 021e1e6239a44c334390ba8baf1d759166dfcedc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 24 Oct 2025 01:46:52 +0200 Subject: [PATCH 081/141] [PyTorch Debug] Fix issue with microbatching + debug value caching (#2108) * fix perf issue Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_perf.py | 11 +++++++---- transformer_engine/pytorch/module/base.py | 8 +++++++- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index 2d4b62b23..ad40c31c0 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -28,13 +28,15 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs) model = torch.nn.Sequential( te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") ).cuda() - NUM_ITERS = 18000 + NUM_ITERS = 1800 elif layer == "transformer": model = torch.nn.Sequential( te.TransformerLayer(1, 1, 1, name="transformer1"), te.TransformerLayer(1, 1, 1, name="transformer2"), ).cuda() - NUM_ITERS = 2000 + NUM_ITERS = 200 + + NUM_INVOCATIONS_PER_ITER = 10 x = torch.randn(1, 1, 1).cuda() @@ -45,8 +47,9 @@ def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs) time_start = time.time() for i in range(NUM_ITERS): - y = model(x) - y.sum().backward() + for _ in range(NUM_INVOCATIONS_PER_ITER): + y = model(x) + y.sum().backward() if debug_tools_initialized: debug_api.step() torch.cuda.synchronize() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 7f571ce01..53b9920a6 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1523,7 +1523,13 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() + self.debug_last_iteration = TEDebugState.get_iteration() + self.debug_enabled_in_this_iteration = debug + else: + # If this is the same iteration as previous invocation of the module, + # we use the debug value from the first invocation in the iteration. + debug = self.debug_enabled_in_this_iteration + return debug def no_debug_features_active(self, quantizers): From 6273cede50f50f6e48314fddb9d22da2d16ef871 Mon Sep 17 00:00:00 2001 From: buptzyb Date: Fri, 24 Oct 2025 21:56:03 +0800 Subject: [PATCH 082/141] [PyTorch] Support delay_wgrad_compute cudagraph (#1948) * support cudagraph dw Signed-off-by: Robin Zhang * fix lint Signed-off-by: Robin Zhang * fix ci Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/graph.py | 92 +++++++++++++++---- transformer_engine/pytorch/module/base.py | 12 ++- .../pytorch/module/grouped_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 2 +- 4 files changed, 85 insertions(+), 23 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 798d3209a..9af9fb887 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -322,14 +322,16 @@ def _make_graphed_callables( fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] + bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))] # For cases with multiple active RNG states, e.g. TP. if graph_safe_rng_available(): for _, state in get_all_rng_states().items(): - for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs): + for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs): fwd_graph.register_generator_state(state) bwd_graph.register_generator_state(state) + bwd_dw_graph.register_generator_state(state) mempool = graph_pool_handle() if pool is None else pool @@ -366,21 +368,8 @@ def _make_graphed_callables( ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." # Filter the TE modules that cudagraph can access. - visited_te_modules = set() - - def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument - if isinstance(module, TransformerEngineBaseModule): - visited_te_modules.add(module) - # If forward is called on a BasicOperation directly the hook will run - elif isinstance(module, BasicOperation): - visited_te_modules.add(module) - # If forward is called on a te.ops.Sequential it is not called on its constituent ops - elif isinstance(module, Sequential): - assert module._module_groups is not None, "Should have been initialized by warmup" - for module_group in module._module_groups: - if isinstance(module_group, OperationFuser): - for basic_op in module_group._basic_ops: - visited_te_modules.add(basic_op) + visited_te_modules = {} + need_bwd_dw_graph = {} # Run warmup and do the above filtering. with torch.cuda.stream(torch.cuda.Stream()): @@ -388,6 +377,31 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] + + def hook_fn( + module, inputs, outputs, func_idx=func_idx + ): # pylint: disable=unused-argument + modules = set() + if isinstance(module, TransformerEngineBaseModule): + modules.add(module) + # If forward is called on a BasicOperation directly the hook will run + elif isinstance(module, BasicOperation): + modules.add(module) + # If forward is called on a te.ops.Sequential it is not called on its constituent ops + elif isinstance(module, Sequential): + assert ( + module._module_groups is not None + ), "Should have been initialized by warmup" + for module_group in module._module_groups: + if isinstance(module_group, OperationFuser): + for basic_op in module_group._basic_ops: + modules.add(basic_op) + if modules: + if func_idx not in visited_te_modules: + visited_te_modules[func_idx] = modules + else: + visited_te_modules[func_idx].update(modules) + for warmup_iter in range(num_warmup_iters): hooks = [] for module in func.modules(): @@ -432,6 +446,15 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument module_params_with_grad ) per_callable_static_input_surfaces[func_idx] = static_input_surface + + # Run wgrad. This is essential for some TE modules when they have + # delay_wgrad_compute enabled. + need_backward_dw = False + for module in visited_te_modules.get(func_idx, set()): + if hasattr(module, "need_backward_dw") and module.need_backward_dw(): + need_backward_dw = True + module.backward_dw() + need_bwd_dw_graph[func_idx] = need_backward_dw else: grad_inputs = None del outputs, grad_inputs @@ -514,6 +537,17 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument allow_unused=allow_unused_input, retain_graph=retain_graph_in_backward, ) + # If no one module needs the backward_dw, the bwd_dw_graph will be empty. + # So skip capturing it. + if need_bwd_dw_graph[per_callable_bwd_idx]: + bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx] + with _graph_context_wrapper(bwd_dw_graph, pool=mempool): + for module in visited_te_modules[per_callable_bwd_idx]: + if ( + hasattr(module, "need_backward_dw") + and module.need_backward_dw() + ): + module.backward_dw() # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. @@ -582,10 +616,12 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument # Capture backward graphs in reverse order per_callable_static_grad_outputs = [] per_callable_static_grad_inputs = [] - for static_input_surface, static_outputs, bwd_graph in zip( + for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip( reversed(per_callable_static_input_surfaces), reversed(per_callable_static_outputs), reversed(bwd_graphs), + reversed(bwd_dw_graphs), + reversed(range(len(per_callable_static_input_surfaces))), ): # For now, assumes all static_outputs require grad static_grad_outputs = tuple( @@ -601,6 +637,11 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument allow_unused=allow_unused_input, retain_graph=retain_graph_in_backward, ) + if need_bwd_dw_graph[bwd_idx]: + with _graph_context_wrapper(bwd_dw_graph, pool=mempool): + for module in visited_te_modules[bwd_idx]: + if hasattr(module, "need_backward_dw") and module.need_backward_dw(): + module.backward_dw() # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that # don't require grad. I couldn't think of a slick one-liner for this pattern. @@ -732,9 +773,10 @@ def functionalized(*user_args, **user_kwargs): ) func = graph_callables[i] + te_modules = visited_te_modules.get(i, set()) if isinstance(func, torch.nn.Module): - def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules): def new_fwd(*user_args, **user_kwargs): # If the module's training-or-eval state matches what we graphed, # run the graph, otherwise run the original forward method @@ -743,7 +785,7 @@ def new_fwd(*user_args, **user_kwargs): if FP8GlobalStateManager.is_fp8_enabled(): fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() for m in func.modules(): - if m not in visited_te_modules: + if m not in te_modules: # Only Set the FP8 meta for the modules included by forward continue if isinstance(m, TransformerEngineBaseModule): @@ -780,7 +822,7 @@ def new_fwd(*user_args, **user_kwargs): return new_fwd - forward = make_graphed_forward(func, func.training, graphed, func.forward) + forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules) if _order is None: func.forward = forward ret.append(func) @@ -789,6 +831,16 @@ def new_fwd(*user_args, **user_kwargs): else: ret.append(graphed) + # Attach backward_dw as an attribute to the graphed callable. + def backward_dw( + need_backward_dw=need_bwd_dw_graph.get(i, False), + bwd_dw_graph=bwd_dw_graphs[i], + ): + if need_backward_dw: + bwd_dw_graph.replay() + + setattr(ret[-1], "backward_dw", backward_dw) + if just_one_callable: return ret[0] diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 53b9920a6..9b6ca9d9c 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -662,6 +662,7 @@ def __init__(self) -> None: self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] + self.wgrad_store = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -1481,12 +1482,21 @@ def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_re """ self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) + def need_backward_dw(self): + """ + Check if this module needs to execute the delayed weight gradient computation. + This method should be used at the beginning of self.backward_dw() to determine if it + should actually be executed or just return without doing anything. + User can also manually call this method to check that before calling into backward_dw(). + """ + return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() + def backward_dw(self): """ Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + if not self.need_backward_dw(): return with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): (wgrad, bgrad), _ = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index aae85e2ca..bba97554c 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -840,7 +840,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + if not self.need_backward_dw(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bae0f2825..ccf5dc095 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2211,7 +2211,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + if not self.need_backward_dw(): return with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() From 060811c93615c7f8f671bdd870e4fe292b997836 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 24 Oct 2025 08:02:59 -0700 Subject: [PATCH 083/141] [Common] Fix checks in quantize_transpose_vector_blockwise_fp4 (#2299) fix checks in unoptimized non-rht fp4 quantize kernel Signed-off-by: Jeremy Berchtold --- .../quantize_transpose_vector_blockwise_fp4.cu | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index fed18c51f..4735fdcbe 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -718,13 +718,11 @@ void quantize_transpose_vector_blockwise_fp4( // raise error if pow2_scale is true NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now"); - if (!return_identity && !return_transpose) { - return; - } + NVTE_CHECK(return_identity || return_transpose, + "At least one of return_identity or return_transpose must be true."); - if (use_2d_quantization && !return_identity) { - return; - } + NVTE_CHECK(return_identity || !use_2d_quantization, + "2D block quantization is only supported when return_identity is true."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -777,7 +775,7 @@ void quantize_transpose_vector_blockwise_fp4( input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY( - output.dtype, 2, OutputType, + return_identity ? output.dtype : output_t.dtype, 2, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); From 87cb26c63c4dc240a77d1b526374631d28810018 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Fri, 24 Oct 2025 17:01:51 -0700 Subject: [PATCH 084/141] [PyTorch] Add max_logit support for MuonClip (#2195) * add max_score for fused/unfused F16 non-CP Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * calculate max per head instead of max over all heads Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix fused attn max_score shape Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert FE to github Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update FE to 1.15.0-rc Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * reduce ew kernels; fix causal masks; add more tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fix to tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * remove logic for flash-attn Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * WIP: add CP support for p2p/a2a/all_gather Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor improvements of implementation/tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * WIP: add thd support Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add thd to UnfusedDPA Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * more fixes for lint Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update to FE 1.15 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove unneeded changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable unfused for thd + pad_between_seqs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fixes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable thd for unfused until bug is fixed Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix all_gather Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix all gather Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename max_score to max_logit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix all_gather Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix all_gather Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * disable fused attn + thd Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- 3rdparty/cudnn-frontend | 2 +- .../attention/run_attention_with_cp.py | 15 +- tests/pytorch/attention/test_attention.py | 68 ++- .../attention/test_attention_with_cp.py | 6 +- tests/pytorch/utils.py | 3 + .../common/fused_attn/fused_attn.cpp | 80 ++-- .../fused_attn_f16_arbitrary_seqlen.cu | 410 ++++++++++++------ .../fused_attn_f16_arbitrary_seqlen.h | 46 +- .../common/fused_attn/fused_attn_fp8.cu | 6 +- transformer_engine/common/fused_attn/utils.h | 5 +- .../include/transformer_engine/fused_attn.h | 79 ++-- .../jax/csrc/extensions/attention.cpp | 32 +- .../dot_product_attention/backends.py | 69 ++- .../dot_product_attention/context_parallel.py | 79 +++- .../dot_product_attention.py | 15 + .../attention/dot_product_attention/utils.py | 91 ++++ .../pytorch/cpp_extensions/fused_attn.py | 18 + transformer_engine/pytorch/csrc/extensions.h | 4 +- .../pytorch/csrc/extensions/attention.cpp | 25 +- 19 files changed, 748 insertions(+), 305 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 80a8e4af4..0b1577c8c 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 80a8e4af4d89d33a2c59d51fcf9fda1c9d368cd4 +Subproject commit 0b1577c8c83401237d601d0d0db5210506705396 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 1edffaf48..5ed67c3d5 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -248,6 +248,7 @@ def run_dpa_with_cp( attn_mask_type=config.attn_mask_type, window_size=config.window_size, softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, ).cuda() if config.softmax_type != "vanilla": core_attn.softmax_offset.requires_grad = True @@ -308,6 +309,7 @@ def run_dpa_with_cp( fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() + max_logit = None with fp8_context: # q, k, v, out in FP8; dout in F16 out = core_attn( @@ -322,6 +324,8 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ) + if config.return_max_logit: + out, max_logit = out if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) @@ -400,6 +404,7 @@ def run_dpa_with_cp( fp8_context = nullcontext() # run attention + max_logit_ = None with fp8_context: # q, k, v, out in FP8; dout in F16 out_ = core_attn( @@ -414,6 +419,8 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ) + if config.return_max_logit: + out_, max_logit_ = out_ if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) @@ -495,15 +502,15 @@ def run_dpa_with_cp( ) atol, rtol, rmse_tol = get_tols(config, dtype) - tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_] - tensors_no_cp = [out, dq, dk, dv, d_softmax_offset] - names = ["out", "dq", "dk", "dv", "d_softmax_offset"] + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"] names_cp = [x + "_cp" for x in names] names_no_cp = [x + "_no_cp" for x in names] is_fp8 = dtype == "fp8" for i, t in enumerate(tensors_no_cp): if t is not None: - if "softmax_offset" not in names[i]: + if "softmax_offset" not in names[i] and "max_logit" not in names[i]: if qkv_format == "bshd": compare_and_assert( t[:, 0], diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 3150c06ab..b05a0447c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -131,6 +131,11 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if qkv_format == "thd" and "padding" not in config.attn_mask_type: + config.attn_mask_type = ( + "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" + ) # Get backends is_training = True @@ -172,7 +177,7 @@ def test_dot_product_attention( # UnfusedDotProductAttention backend if unfused_attn_supported: - unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( + unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = _run_dot_product_attention( dtype, config, "UnfusedDotProductAttention", @@ -186,7 +191,7 @@ def test_dot_product_attention( # FusedAttention backend if fused_attn_supported: if len(fused_attn_backends) == 1: - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -198,7 +203,7 @@ def test_dot_product_attention( ) if len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -209,7 +214,7 @@ def test_dot_product_attention( is_training, ) os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( + fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -222,7 +227,7 @@ def test_dot_product_attention( # FlashAttention backend if flash_attn_supported: - flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( + flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention( dtype, config, "FlashAttention", @@ -243,6 +248,8 @@ def test_dot_product_attention( if unfused_attn_supported and fused_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + if config.return_max_logit: + torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols) for i, _ in enumerate(unfused_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: @@ -266,6 +273,33 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False) +model_configs_max_logit = { + # test: ModelConfig(b, sq, hq, dqk) + "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), + "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), + "max_logit_4": ModelConfig( + 8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias" + ), + "max_logit_5": ModelConfig( + 8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0) + ), + "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_max_logit]) +@pytest.mark.parametrize("model", model_configs_max_logit.keys()) +@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"]) +def test_dpa_max_logit(dtype, model_configs, model, qkv_layout): + """Test DotProductAttention module with checkpointing""" + config = model_configs[model] + config.return_max_logit = True + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False) + + model_configs_softmax = { # test: ModelConfig(b, sq, hq, dqk) "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), @@ -962,6 +996,8 @@ def _run_dot_product_attention( layout = layout.replace("d", "dqk") tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") + # tensor: with padding tokens + # tensor_orig: without padding tokens tensor_orig = tensor if qkv_format == "thd" and pad_between_seqs: tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1071,6 +1107,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: layer_number=1, attention_type=config.attn_type, softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() @@ -1108,16 +1145,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: alibi_slopes=alibi_slopes, fast_zero_fill=True, ) + max_logit = None + if config.return_max_logit: + out, max_logit = out if is_training: out.backward(d_out) + d_softmax_offset = None if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad + if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad, d_softmax_offset) + return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None, d_softmax_offset) + return out, max_logit, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1146,14 +1188,18 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset) + return ( + out_orig, + max_logit, + (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset), + ) else: - return out_orig, (None, None, None, d_softmax_offset) + return out_orig, max_logit, (None, None, None, d_softmax_offset) else: if is_training: - return out, (q.grad, k.grad, v.grad, d_softmax_offset) + return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None, d_softmax_offset) + return out, max_logit, (None, None, None, d_softmax_offset) model_configs_te_layer = { diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 2c7f9d857..e5c856acd 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -137,8 +137,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: ModelConfig(b, sq, hq, dqk) - "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA - "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA "cp_1_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA @@ -183,7 +183,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_formats = ["bshd", "sbhd", "thd"] cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] if test_essential: - configs = ["cp_1_0", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} dtypes = ["bf16", "fp8"] qkv_formats = ["sbhd", "thd"] diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 72a1b3b53..485c739c0 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -205,6 +205,7 @@ def __init__( window_size: Tuple[int, int] = (-1, -1), context_parallel: bool = False, cp_comm_type: str = "p2p", + return_max_logit=False, total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -233,6 +234,7 @@ def __init__( self.window_size = check_set_window_size(self.attn_mask_type, window_size) self.context_parallel = context_parallel self.cp_comm_type = cp_comm_type + self.return_max_logit = return_max_logit self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -318,6 +320,7 @@ def test(): is_training=is_training, inference_params=inference_params, softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, ) ( use_flash_attention, diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 77cd8d235..f6ee37d4c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool return_max_logit) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -187,7 +187,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000)) { + (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -216,7 +216,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) { + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) { flag_m512 = true; } if ( @@ -418,8 +418,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, - size_t max_seqlen, bool is_training, float attn_scale, - float dropout, NVTE_QKV_Layout qkv_layout, + size_t max_seqlen, bool is_training, bool return_max_logit, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, @@ -460,7 +460,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -474,10 +474,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, input_Bias, - input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, + input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -544,7 +544,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -602,7 +602,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { @@ -680,7 +680,8 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -695,12 +696,12 @@ void nvte_fused_attn_fwd_kvpacked( #if (CUDNN_VERSION >= 8903) fused_attn_arbitrary_seqlen_fwd_kvpacked( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, + output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -777,7 +778,7 @@ void nvte_fused_attn_bwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -832,18 +833,16 @@ void nvte_fused_attn_bwd_kvpacked( } } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -913,7 +912,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -928,12 +928,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, - window_size_right, input_Q, input_K, input_V, input_Bias, input_SoftmaxOffset, output_O, - Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, - wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -1008,7 +1008,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index ba0f84578..950ced61b 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,10 +53,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrSoftmaxStats, + void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, @@ -102,36 +102,40 @@ void fused_attn_arbitrary_seqlen_fwd_impl( } const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + bool generate_stats = !return_max_logit; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - num_pages_k, - num_pages_v, - page_size_k, - page_size_v, - max_pages_per_seq_k, - max_pages_per_seq_v, - bias_b, - bias_h, - scaling_factor, - is_training, - dropout_probability, - layout, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - true, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET}; + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, + max_pages_per_seq_k, + max_pages_per_seq_v, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + true, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + return_max_logit, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -141,7 +145,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // V std::shared_ptr, // attn_scale std::shared_ptr, // O - std::shared_ptr, // Stats + std::shared_ptr, // S1 + std::shared_ptr, // S2 std::shared_ptr, // bias std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q @@ -244,6 +249,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_is_inference(false) + .set_generate_stats(generate_stats) .set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); @@ -317,7 +323,36 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_sink_token(softmax_offset); } - auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); + std::shared_ptr Max, Sum_Exp; + if (is_ragged_q && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + } + if (return_max_logit) { + Max = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Max") + .set_dim({b, h, s_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Sum_Exp") + .set_dim({b, h, s_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Max->set_stride({h * s_q, s_q, 1, 1}); + Sum_Exp->set_stride({h * s_q, s_q, 1, 1}); + } + sdpa_options.set_logit_max(Max); + sdpa_options.set_score_sum_exp(Sum_Exp); + } + + auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, @@ -332,17 +367,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { - offset_stats = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); - } else { - Stats->set_stride({h * s_q, s_q, 1, 1}); + if (!return_max_logit) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } } std::tuple, // Q @@ -351,7 +382,8 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // attn_scale std::shared_ptr> // O key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = std::make_tuple(Stats); + auto Stats_tuple = + generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); auto softmax_offset_tuple = is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); @@ -384,7 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, softmax_offset, seq_q, seq_kv, + auto [mha_graph, Q, K, V, attn_scale, O, S1, S2, bias, softmax_offset, seq_q, seq_kv, page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); @@ -417,9 +449,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // Build variant pack std::unordered_map, void *> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, - {V, devPtrV}, {attn_scale, &scaling_factor}, - {O, devPtrO}, {Stats, devPtrSoftmaxStats}}; + {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, + {O, devPtrO}, {S1, devPtrS1}}; + + if (return_max_logit) { + variant_pack[S2] = devPtrS2; + } if (is_bias) { variant_pack[bias] = devPtrBias; @@ -561,35 +596,38 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - scaling_factor, - true, - dropout_probability, - layout, - bias_type, - mask_type, - softmax_type, - window_size_left, - window_size_right, - deterministic, - tensorType, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET}; + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + 0, + 0, + 0, + 0, + 0, + 0, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + deterministic, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + false, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1001,12 +1039,13 @@ void fused_attn_arbitrary_seqlen_bwd_impl( using namespace transformer_engine::fused_attn; void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool is_training, bool return_max_logit, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_QKV->data.dtype; @@ -1037,7 +1076,8 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( } void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrCuSeqlens = cu_seqlens->data.dptr; void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; @@ -1051,14 +1091,34 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Max->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_Max->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; + } + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1080,8 +1140,15 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1105,11 +1172,11 @@ void fused_attn_arbitrary_seqlen_fwd_qkvpacked( fused_attn_arbitrary_seqlen_fwd_impl( batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, - nullptr, devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); + return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, + get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1221,14 +1288,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1260,7 +1328,8 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( } void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1285,14 +1354,34 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1314,8 +1403,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1340,11 +1436,12 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1471,14 +1568,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { + bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; @@ -1488,7 +1585,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; @@ -1525,14 +1623,34 @@ void fused_attn_arbitrary_seqlen_fwd( size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_S->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_S->data.dtype = DType::kFloat32; } - output_S->data.dtype = DType::kFloat32; + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -1554,8 +1672,15 @@ void fused_attn_arbitrary_seqlen_fwd( Aux_CTX_Tensors->size = i; } else if (Aux_CTX_Tensors->size >= 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS = output_S->data.dptr; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -1580,11 +1705,12 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrSoftmaxOffset, devPtrS, - devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, - devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); + return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index b9658b053..a3181c629 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -20,12 +20,13 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) void fused_attn_arbitrary_seqlen_fwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_QKV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + bool is_training, bool return_max_logit, float attn_scale, float p_dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, + const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, + Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, + const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, + cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_qkvpacked( size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, @@ -41,14 +42,15 @@ void fused_attn_arbitrary_seqlen_fwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, + float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, + NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, + const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd_kvpacked( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, @@ -68,14 +70,14 @@ void fused_attn_arbitrary_seqlen_fwd( size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, + const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 21c544491..7b85be972 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1710,7 +1710,8 @@ void fused_attn_fp8_fwd_impl_v1( qkv_tensor_type, o_tensor_type, cudnn_frontend::DataType_t::NOT_SET, - cudnn_frontend::DataType_t::NOT_SET}; + cudnn_frontend::DataType_t::NOT_SET, + false}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2038,7 +2039,8 @@ void fused_attn_fp8_bwd_impl_v1( qkv_tensor_type, o_tensor_type, do_tensor_type, - dqkv_tensor_type}; + dqkv_tensor_type, + false}; namespace fe = cudnn_frontend; using graph_and_tensors = diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index f03774f8e..72047a73f 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -115,20 +115,21 @@ struct FADescriptor_v1 { cudnn_frontend::DataType_t o_tensor_type; cudnn_frontend::DataType_t do_tensor_type; cudnn_frontend::DataType_t dqkv_tensor_type; + bool generate_max_sum_exp; bool operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, - o_tensor_type, do_tensor_type, dqkv_tensor_type) < + o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, - rhs.dqkv_tensor_type); + rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index a150978c4..518fad20d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -190,29 +190,30 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * - * \param[in] is_training Whether the model is in training mode. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] softmax_type The attention softmax type. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). + * \param[in] is_training Whether the model is in training mode. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). + * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool return_max_logit); /*! \brief Compute dot product attention with packed QKV input. * @@ -255,6 +256,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -266,13 +268,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked( - const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, - bool is_training, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -381,6 +386,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] max_seqlen_kv Max sequence length used for computing for KV. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -399,7 +405,7 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); @@ -520,6 +526,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. @@ -531,18 +538,16 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 9277569e1..ffc0706fe 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -22,7 +22,8 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); return backend; } @@ -179,17 +180,18 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); + false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), + nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_fwd_kvpacked( q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { nvte_fused_attn_fwd( q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), @@ -197,8 +199,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, + kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported QKVLayout."); @@ -276,7 +278,8 @@ static void FusedAttnForwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -294,7 +297,7 @@ static void FusedAttnForwardImpl( nvte_fused_attn_fwd_qkvpacked( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, + q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { @@ -308,8 +311,8 @@ static void FusedAttnForwardImpl( s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, + q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -323,7 +326,7 @@ static void FusedAttnForwardImpl( dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else { @@ -542,7 +545,8 @@ static void FusedAttnBackwardImpl( auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 6c19d868a..d4903be90 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -58,6 +58,8 @@ combine_and_quantize, combine_and_dequantize, print_quantizers, + ConvertTHDtoBSHD, + ConvertBSHDtoTHD, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, @@ -201,6 +203,7 @@ def __init__( attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, softmax_type: str = "vanilla", + return_max_logit: Optional[bool] = False, ) -> None: super().__init__() @@ -209,6 +212,7 @@ def __init__( self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number self.softmax_type = softmax_type + self.return_max_logit = return_max_logit def mask_func(x, y): return ( @@ -217,6 +221,7 @@ def mask_func(x, y): else attention_mask_func(x, y) ) + self.mask_func = mask_func self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) # Dropout. Note that for a single iteration, this layer will generate @@ -238,6 +243,8 @@ def forward( qkv_layout: str = "sbh3d", cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + max_seqlen_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + max_seqlen_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, @@ -261,6 +268,9 @@ def forward( if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) + # convert to sbhd + # training: bshd, thd + # inference: bshd, sbhd_2bshd, thd_2bshd if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now query_layer, key_layer, value_layer = [ @@ -269,9 +279,8 @@ def forward( if qkv_format == "sbhd_2bshd": key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]] - total_tokens, batch_size = None, None if qkv_format == "thd_2bshd": - total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0] + batch_size = key_layer.shape[0] query_layer = tex.convert_thd_to_bshd( query_layer, cu_seqlens_q, @@ -281,6 +290,26 @@ def forward( query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + if qkv_format == "thd": + assert cu_seqlens_q is not None and cu_seqlens_kv is not None + assert max_seqlen_q is not None and max_seqlen_kv is not None + query_layer = ConvertTHDtoBSHD.apply( + query_layer, + cu_seqlens_q, + max_seqlen_q, + ) + key_layer, value_layer = [ + ConvertTHDtoBSHD.apply( + x, + cu_seqlens_kv, + max_seqlen_kv, + ) + for x in [key_layer, value_layer] + ] + query_layer, key_layer, value_layer = [ + x.transpose(0, 1).contiguous() for x in [query_layer, key_layer, value_layer] + ] + batch_size, max_seqlen_q, max_seqlen_kv = ( query_layer.shape[1], query_layer.shape[0], @@ -426,6 +455,15 @@ def forward( matmul_result, None, None, dP_quantizer, "dP_quantizer", None ) + # max attention score + max_logit = None + if self.return_max_logit: + # matmul_result [b, np, sq, dk], max_logit [np] + max_logit = matmul_result + if attn_mask_type != "no_mask": + max_logit = self.mask_func(matmul_result, attention_mask) + max_logit = torch.amax(max_logit, dim=(0, 2, 3)) + # add attention sink to the last column: [b, np, sq, sk+1] if self.softmax_type != "vanilla": matmul_result = torch.cat( @@ -506,14 +544,13 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [tq, np, hn] - context_layer = tex.convert_bshd_to_thd( + context_layer = ConvertBSHDtoTHD.apply( context_layer, cu_seqlens_q, - total_tokens, ) # [tq, np, hn] --> [tq, hp] - context_layer = context_layer.view(total_tokens, -1) + context_layer = context_layer.view(context_layer.shape[0], -1) if fp8: # quantize and dequantize O to emulate FP8 @@ -529,6 +566,9 @@ def forward( if fp8_output: context_layer = O_quantizer(context_layer) + if self.return_max_logit: + return context_layer, max_logit + return context_layer @@ -1067,6 +1107,7 @@ def forward( softmax_offset, fp8_output, layer_number, + return_max_logit, ): # pylint: disable=missing-function-docstring @@ -1102,6 +1143,7 @@ def forward( # FP8 attention: torch.float16 or torch.bfloat16 out_nominal_dtype = q.dtype + max_logit = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] @@ -1129,7 +1171,7 @@ def forward( # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 # fp8_dtype = tex.DType.kFloat8E4M3 # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1205,7 +1247,7 @@ def forward( qkvo_tensors = (q, k, v, out) else: # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1233,6 +1275,7 @@ def forward( window_size, rng_gen, softmax_offset, + return_max_logit, ) out = out_ out_ret = out_ @@ -1327,10 +1370,12 @@ def forward( ctx.use_FAv2_bwd = use_FAv2_bwd ctx.deterministic = deterministic + if return_max_logit: + return out_ret, *max_logit return out_ret @staticmethod - def backward(ctx, d_out): + def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring # d_out is expected to be in FP8 if is_output_fp8=True, @@ -1574,6 +1619,7 @@ def backward(ctx, d_out): d_softmax_offset, None, None, + None, ) @@ -1614,6 +1660,7 @@ def __init__( layer_number: Optional[int] = None, deterministic: bool = False, softmax_type: str = "vanilla", + return_max_logit: Optional[bool] = False, ) -> None: super().__init__() @@ -1627,6 +1674,7 @@ def __init__( self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic self.softmax_type = softmax_type + self.return_max_logit = return_max_logit def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1846,6 +1894,7 @@ def forward( softmax_offset=softmax_offset, fp8_output=fp8_output, layer_number=self.layer_number, + return_max_logit=self.return_max_logit, ) else: with self.attention_dropout_ctx(): @@ -1881,7 +1930,11 @@ def forward( softmax_offset, fp8_output, self.layer_number, + self.return_max_logit, ) + if self.return_max_logit: + # ...hd -> ...(hd) + return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index e5ee8cc7d..f312cac79 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -617,6 +617,7 @@ def cp_p2p_fwd_fused_attn( rank, step, cp_size, + return_max_logit, q_part, k_part, v_part, @@ -693,7 +694,7 @@ def cp_p2p_fwd_fused_attn( fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step - out_per_step, aux_ctx_tensors = fused_attn_fwd( + out_per_step, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q_, max_seqlen_kv_, @@ -713,6 +714,7 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_q_padded=cu_seqlens_q_padded_, cu_seqlens_kv_padded=cu_seqlens_kv_padded_, **fp8_meta_kwargs, + return_max_logit=return_max_logit, ) if fp8: @@ -721,7 +723,9 @@ def cp_p2p_fwd_fused_attn( softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None - return out_per_step, softmax_lse_per_step, rng_states, attn_bias + if return_max_logit: + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None def cp_p2p_fwd_flash_attn( @@ -1086,6 +1090,7 @@ def forward( attn_bias, deterministic, use_fused_attention, + return_max_logit, fp8, fp8_meta, cp_group, @@ -1156,6 +1161,8 @@ def forward( amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] O_quantizer_per_step = [None for _ in range(cp_size)] + max_logit_per_step = [None for _ in range(cp_size)] + max_logit = None assert isinstance(k, q.__class__) and isinstance( v, q.__class__ @@ -1244,6 +1251,10 @@ def forward( q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + if return_max_logit: + max_logit_per_step = [ + torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size) + ] # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( @@ -1418,6 +1429,7 @@ def forward( rank, i, cp_size, + return_max_logit, ] else: flash_attn_inputs = [ @@ -1462,6 +1474,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1488,6 +1501,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1514,6 +1528,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn( *fused_attn_inputs, *prepare_outputs, section ) @@ -1541,6 +1556,7 @@ def forward( softmax_lse_per_step[i], rng_states[i], attn_biases[i], + max_logit_per_step[i], ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( @@ -1600,11 +1616,20 @@ def forward( softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), softmax_lse_per_step[i - 1], ) + if return_max_logit: + if i == 1: + max_logit = torch.clone(max_logit_per_step[0]) + else: + max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + if return_max_logit: + torch.distributed.all_reduce( + max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) second_half_lse_seqlen = None if causal and rank < (cp_size - 1): @@ -1682,6 +1707,10 @@ def forward( elif qkv_format == "sbhd": # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False + ) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) @@ -1811,10 +1840,12 @@ def forward( nvtx_range_pop(f"{nvtx_label}") + if return_max_logit: + return out_ret, max_logit return out_ret @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *_args): # pylint: disable=missing-function-docstring # add NVTX range @@ -2522,6 +2553,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -2577,6 +2609,7 @@ def forward( attn_bias, deterministic, use_fused_attention, + return_max_logit, window_size, cp_group, cp_stream, @@ -2682,6 +2715,8 @@ def forward( softmax_lse_per_step = [None, None] rng_states = [None, None] out = torch.empty_like(q) + max_logit_per_step = [None, None] + max_logit = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): @@ -2712,7 +2747,11 @@ def forward( # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + ( + out_per_step[i], + [softmax_lse_per_step[i], rng_states[i]], + *max_logit_, + ) = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv_, @@ -2732,7 +2771,10 @@ def forward( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], window_size=window_size_per_step[i], + return_max_logit=return_max_logit, ) + if return_max_logit: + max_logit_per_step[i] = max_logit_[0] else: fa_forward_args_thd = get_fa_args( True, @@ -2767,14 +2809,22 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + if return_max_logit and i == 0: + max_logit = torch.clone(max_logit_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": out[:, i - 1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + if return_max_logit: + max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + if return_max_logit: + torch.distributed.all_reduce( + max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) if use_fused_attention: if qkv_format == "bshd": @@ -2811,10 +2861,12 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") + if return_max_logit: + return out, max_logit return out @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *_args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3035,6 +3087,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3065,6 +3118,7 @@ def forward( attn_bias, deterministic, use_fused_attention, + return_max_logit, window_size, fp8, fp8_meta, @@ -3158,6 +3212,7 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] fwd_nominal_dtype = q.dtype fused_attn_backend = None + max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( dpa_utils.get_attention_quantizers(fp8, quantizers) @@ -3203,7 +3258,7 @@ def forward( Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) ] - out_, aux_ctx_tensors = fused_attn_fwd( + out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3226,6 +3281,7 @@ def forward( **fp8_meta_kwargs, softmax_type=softmax_type, softmax_offset=softmax_offset, + return_max_logit=return_max_logit, ) if isinstance(out_, Float8Tensor): out_fp8 = out_ @@ -3276,6 +3332,10 @@ def forward( out_ = flash_attn_a2a_communicate( out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + *max_logit, 0, cp_size, cp_group, cp_stream, False + ) if use_fused_attention: if qkv_format == "bshd": @@ -3362,10 +3422,12 @@ def forward( ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") + if return_max_logit: + return out_ret, max_logit return out_ret @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *_args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3599,6 +3661,7 @@ def backward(ctx, dout): None, None, None, + None, d_softmax_offset, None, ) @@ -3637,6 +3700,7 @@ def attn_forward_func_with_cp( softmax_offset=None, fp8_output=False, layer_number=1, + return_max_logit=False, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3784,6 +3848,7 @@ def attn_forward_func_with_cp( attn_bias, deterministic, use_fused_attention, + return_max_logit, ] if cp_comm_type in ["p2p", "a2a+p2p"]: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 6d9ce9a52..0d1c0b0c0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -255,6 +255,12 @@ class DotProductAttention(TransformerEngineBaseModule): where alpha is a learnable parameter in shape [h]. 'off-by-one' and 'learnable' softmax types are also called sink attention ('zero sink' and 'learnable sink'). + return_max_logit: Optional[bool], default = `False` + If true, returns the maximum attention score that can be used in a Muon optimizer to + rescale the Q and K projection weights (see `Muon is Scalable for LLM Training + `_). + max_logit = max(S), where S = mask(Q*K^T*softmax_scale + bias) in shape [b, h, s_q, s_kv], + and max_logit is in shape [h]. Parallelism parameters ---------------------- @@ -311,6 +317,7 @@ def __init__( cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", + return_max_logit: Optional[bool] = False, ) -> None: super().__init__() @@ -394,6 +401,7 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.return_max_logit = return_max_logit self.softmax_type = softmax_type if self.softmax_type == "vanilla": @@ -431,6 +439,7 @@ def __init__( deterministic=self.deterministic, **attn_kwargs, softmax_type=self.softmax_type, + return_max_logit=self.return_max_logit, ) self.unfused_attention = UnfusedDotProductAttention( @@ -439,6 +448,7 @@ def __init__( **attn_kwargs, layer_number=layer_number, softmax_type=self.softmax_type, + return_max_logit=self.return_max_logit, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -1303,6 +1313,7 @@ def forward( fp8_meta=self.fp8_meta, inference_params=inference_params, softmax_type=self.softmax_type, + return_max_logit=self.return_max_logit, ) global _attention_backends if is_in_onnx_export_mode(): @@ -1502,6 +1513,8 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, @@ -1523,6 +1536,8 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 4cb39cda0..51279bd37 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -229,6 +229,8 @@ class AttentionParams: Inference-related parameters. See InferenceParams for details. softmax_type: str, default = "vanilla" The type of softmax operation. See DotProductAttention for details. + return_max_logit: bool, default = `False` + Whether to output max_logit. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -257,6 +259,7 @@ class AttentionParams: fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" + return_max_logit: bool = False def __eq__(self, other): """ @@ -330,6 +333,7 @@ def get_attention_backend( fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type + return_max_logit = attention_params.return_max_logit # Run config logger = logging.getLogger("DotProductAttention") @@ -477,6 +481,20 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0") use_fused_attention = False + # Filter: Return max_logit + if return_max_logit: + if use_flash_attention: + use_flash_attention = False + logger.debug("Disabling FlashAttention for max_logit") + if use_fused_attention and qkv_format == "thd": + use_fused_attention = False + logger.debug("Disabling FusedAttention for max_logit with qkv_format = thd") + if fp8 and fp8_meta["recipe"].fp8_dpa: + use_flash_attention = False + use_fused_attention = False + use_unfused_attention = False + logger.debug("Disabling all backends for max_logit with FP8 attention") + # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- @@ -913,6 +931,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt head_dim_v, window_size[0], window_size[1], + return_max_logit, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") @@ -1649,6 +1668,78 @@ def backward(ctx, grad_output): return None, None, _pack_tensor(indices, grad_output) +class ConvertTHDtoBSHD(torch.autograd.Function): + """ + Convert a tensor from qkv_format = thd to qkv_format = bshd. + """ + + @staticmethod + def forward(ctx, thd_tensor, cu_seqlens, max_seqlen): + # pylint: disable=missing-function-docstring + batch_size = cu_seqlens.shape[0] - 1 + if not thd_tensor.is_contiguous(): + thd_tensor = thd_tensor.contiguous() + bshd_tensor = tex.convert_thd_to_bshd( + thd_tensor, + cu_seqlens, + batch_size, + max_seqlen, + ) + ctx.save_for_backward(cu_seqlens) + ctx.num_tokens = thd_tensor.shape[0] + return bshd_tensor + + @staticmethod + def backward(ctx, bshd_tensor): + # pylint: disable=missing-function-docstring + (cu_seqlens,) = ctx.saved_tensors + if not bshd_tensor.is_contiguous(): + bshd_tensor = bshd_tensor.contiguous() + thd_tensor = tex.convert_bshd_to_thd( + bshd_tensor, + cu_seqlens, + ctx.num_tokens, + ) + return thd_tensor, None, None + + +class ConvertBSHDtoTHD(torch.autograd.Function): + """ + Convert a tensor from qkv_format = bshd to qkv_format = thd. + """ + + @staticmethod + def forward(ctx, bshd_tensor, cu_seqlens): + # pylint: disable=missing-function-docstring + num_tokens = cu_seqlens[-1] + max_seqlen = bshd_tensor.shape[1] + if not bshd_tensor.is_contiguous(): + bshd_tensor = bshd_tensor.contiguous() + thd_tensor = tex.convert_bshd_to_thd( + bshd_tensor, + cu_seqlens, + num_tokens, + ) + ctx.save_for_backward(cu_seqlens) + ctx.max_seqlen = max_seqlen + return thd_tensor + + @staticmethod + def backward(ctx, thd_tensor): + # pylint: disable=missing-function-docstring + (cu_seqlens,) = ctx.saved_tensors + batch_size = cu_seqlens.shape[0] - 1 + if not thd_tensor.is_contiguous(): + thd_tensor = thd_tensor.contiguous() + bshd_tensor = tex.convert_thd_to_bshd( + thd_tensor, + cu_seqlens, + batch_size, + ctx.max_seqlen, + ) + return bshd_tensor, None + + def get_qkv_format( qkv_layout: str = "bshd_bshd_bshd", inference_params: InferenceParams = None, diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index f80c001a1..eb43c75f6 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -139,6 +139,7 @@ def fused_attn_fwd( window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, + return_max_logit: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -216,6 +217,8 @@ def fused_attn_fwd( softmax_offset: torch.Tensor, default = None softmax offset tensor in shape [1, h_q, 1, 1]. See softmax_type in DotProductAttention for details. + return_max_logit: bool, default = False + whether to return the maximum attention score Returns ---------- @@ -246,6 +249,7 @@ def fused_attn_fwd( rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; [seed, offset], dtype uint64 + max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None """ if attn_scale is None: @@ -315,8 +319,22 @@ def fused_attn_fwd( softmax_offset, rng_gen, rng_elts_per_thread, + return_max_logit, ) + if return_max_logit: + qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1] + # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + stats = output_tensors[1] + torch.log(output_tensors[2]) + amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3) + # Max -> max_logit [h] + max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype) + aux_ctx_tensors = [stats] + aux_ctx_tensors.extend(output_tensors[3:]) + return output_tensors[0], aux_ctx_tensors, max_logit + # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d86a96959..79fb79842 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right); + int64_t window_size_right, bool return_max_logit); std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, @@ -94,7 +94,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread); + size_t rng_elts_per_thread, bool return_max_logit); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 344bc4ab0..f66c8aa61 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,11 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right) { + int64_t window_size_right, bool return_max_logit) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, - max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, + return_max_logit); return fused_attention_backend; } @@ -106,7 +107,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread) { + size_t rng_elts_per_thread, bool return_max_logit) { auto none = py::none(); // create QKV tensor wrappers @@ -228,8 +229,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -249,7 +251,9 @@ std::vector fused_attn_fwd( }; // allocate memory for nvte_aux_tensor_pack.tensors // f16_max512 : S [b, h, sq, skv] - // f16_arbitrary: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // f16_arbitrary: + // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] size_t i = 0; at::Tensor output_tensor; @@ -258,8 +262,8 @@ std::vector fused_attn_fwd( allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); - // fp8 has an additional softmax stats tensor, ZInv - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor + if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); @@ -285,8 +289,9 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory From d2945c6a571e3978677614d1fe08779966a5a4ef Mon Sep 17 00:00:00 2001 From: Tong Liu Date: Mon, 27 Oct 2025 16:10:58 +0800 Subject: [PATCH 085/141] [PyTorch] Use dummy wgrad in GroupedLinear (#2305) dummy wgrad Signed-off-by: tongliu Signed-off-by: Xin Yao Co-authored-by: Xin Yao --- .../pytorch/module/grouped_linear.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index bba97554c..4d6b2f23b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -13,6 +13,7 @@ from transformer_engine.common.recipe import Recipe from .base import ( + get_dummy_wgrad, get_multi_stream_cublas_workspace, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -447,18 +448,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad): ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None From d7c9777e611a90337fffb0482a62ee2b60ef0353 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 27 Oct 2025 18:13:30 -0400 Subject: [PATCH 086/141] Remove `nvidia-mathdx` dependency (#2295) * Remove nvidia-mathdx dep Signed-off-by: Kirthi Shankar Sivamani * Fix SR Signed-off-by: Kirthi Shankar Sivamani * Add comment Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/build.yml | 8 +- build_tools/wheel_utils/build_wheels.sh | 2 +- pyproject.toml | 3 +- transformer_engine/common/CMakeLists.txt | 23 ---- .../hadamard_transform_cast_fusion.cu | 18 +-- ...quantize_transpose_vector_blockwise_fp4.cu | 27 ++--- transformer_engine/common/util/curanddx.hpp | 106 ++++++++++++++++++ .../common/util/nvfp4_transpose.cuh | 28 +++-- 8 files changed, 145 insertions(+), 70 deletions(-) create mode 100644 transformer_engine/common/util/curanddx.hpp diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 506bc83f0..f40b28189 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,7 +19,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake==3.21.0 pybind11[global] ninja nvidia-mathdx==25.1.1 + pip install cmake==3.21.0 pybind11[global] ninja - name: 'Checkout' uses: actions/checkout@v3 with: @@ -43,7 +43,7 @@ jobs: run: | apt-get update apt-get install -y git python3.9 pip cudnn9-cuda-12 - pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1 + pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript - name: 'Checkout' uses: actions/checkout@v3 with: @@ -63,7 +63,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install pybind11[global] nvidia-mathdx==25.1.1 + run: pip install pybind11[global] - name: 'Checkout' uses: actions/checkout@v3 with: @@ -83,7 +83,7 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript nvidia-mathdx==25.1.1 + run: pip install torch pybind11[global] einops onnxscript - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 954a8f1c6..d0055b791 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -23,7 +23,7 @@ git checkout $TARGET_BRANCH git submodule update --init --recursive # Install deps -/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1 +/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel if $BUILD_METAPACKAGE ; then cd /TransformerEngine diff --git a/pyproject.toml b/pyproject.toml index 8692ad961..35a7c2072 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,8 +3,7 @@ # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "nvidia-mathdx==25.1.1", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" - diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 175abd353..e388dd794 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -98,28 +98,6 @@ set(CUTLASS_TOOLS_INCLUDE_DIR # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) -# NVIDIA MathDX include directory (from Python package install location) -if(NOT DEFINED MATHDX_INCLUDE_DIR) - execute_process( - COMMAND ${Python_EXECUTABLE} -m pip show nvidia-mathdx - OUTPUT_VARIABLE _PIP_SHOW_MATHDX - ERROR_VARIABLE _PIP_SHOW_MATHDX_ERR - RESULT_VARIABLE _PIP_SHOW_MATHDX_RES - OUTPUT_STRIP_TRAILING_WHITESPACE) - if(NOT _PIP_SHOW_MATHDX_RES EQUAL 0) - message(FATAL_ERROR "Failed to query 'nvidia-mathdx' with pip (using ${Python_EXECUTABLE}): ${_PIP_SHOW_MATHDX_ERR}") - endif() - string(REGEX MATCH "Location: ([^\n\r]+)" _MATHDX_LOC_MATCH "${_PIP_SHOW_MATHDX}") - if(NOT _MATHDX_LOC_MATCH) - message(FATAL_ERROR "Could not parse installation location for 'nvidia-mathdx'. Output was:\n${_PIP_SHOW_MATHDX}") - endif() - set(MATHDX_LOCATION "${CMAKE_MATCH_1}") - set(MATHDX_INCLUDE_DIR "${MATHDX_LOCATION}/nvidia/mathdx/include") -endif() -if(NOT EXISTS "${MATHDX_INCLUDE_DIR}") - message(FATAL_ERROR "MATHDX include directory not found at ${MATHDX_INCLUDE_DIR}. Set MATHDX_INCLUDE_DIR or ensure 'nvidia-mathdx' is installed for ${Python_EXECUTABLE}.") -endif() - # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) @@ -263,7 +241,6 @@ target_link_libraries(transformer_engine PUBLIC target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) -target_include_directories(transformer_engine PRIVATE ${MATHDX_INCLUDE_DIR}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu index 263a32623..12f02dba6 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -19,9 +19,9 @@ #include "common/common.h" #include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "curanddx.hpp" #include "cutlass/arch/barrier.h" #include "cutlass/cutlass.h" #include "cutlass/gemm/collective/builders/sm100_common.inl" @@ -38,15 +38,6 @@ namespace transformer_engine { namespace detail { namespace { -// Define a cuRANDDx descriptor -// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. -// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., -// if shared memory, if needed, is enough for the described problem, usually not applicable. - -// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + curanddx::SM<800>() + curanddx::Thread()); - - using namespace cute; using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor @@ -502,8 +493,9 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, // Initialize RNG for tile const size_t rng_sequence = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); uint4 random_uint4 = uint4{0, 0, 0, 0}; CUTLASS_PRAGMA_UNROLL @@ -511,7 +503,7 @@ rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); // auto acc_scale = acc_scales[v]; if constexpr (kEnableStochasticRounding) { - random_uint4 = dist.generate4(rng); + random_uint4 = rng.generate4(); output_frgs[v] = StochasticNumericConverter( cutlass::multiplies>{}( compute_frgs[v], diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index 4735fdcbe..b49a54fbd 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -17,9 +17,9 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/curanddx.hpp" #include "common/util/ptx.cuh" #include "common/utils.cuh" -#include "curanddx.hpp" namespace transformer_engine { @@ -33,14 +33,6 @@ using std::uint8_t; using transformer_engine::detail::TypeExtrema; -// Define a cuRANDDx descriptor -// Note curanddx::PhiloxRounds<4> means 4 rounds of philox4_32. If the operator is not specified, it will be default to 10. -// curanddx::SM<800>() does NOT mean the code can only run on SM 800. The operator is used for do some internal checks, e.g., -// if shared memory, if needed, is enough for the described problem, usually not applicable. -// curanddx doc: https://docs.nvidia.com/cuda/curanddx/index.html -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + - curanddx::SM<800>() + curanddx::Thread()); - // clang-format off /* @@ -209,12 +201,15 @@ __device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_ return global_encode_scale; } -__device__ __forceinline__ uint32_t get_rbits(RNG& rng, uint4& random_uint4, int& rnd_idx) { +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>& + rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4& random_uint4, int& rnd_idx) { if (rnd_idx == 4) { rnd_idx = 0; - curanddx::uniform_bits dist; - random_uint4 = dist.generate4(rng); + random_uint4 = rng.generate4(); } + // Treat uint4 as an array of 4x uint32_t elements for indexing const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); const uint32_t rbits = rbits_arr[rnd_idx++]; @@ -348,9 +343,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; - uint4 random_uint4 = kApplyStochasticRounding ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x diff --git a/transformer_engine/common/util/curanddx.hpp b/transformer_engine/common/util/curanddx.hpp new file mode 100644 index 000000000..4d7c90a01 --- /dev/null +++ b/transformer_engine/common/util/curanddx.hpp @@ -0,0 +1,106 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ + +namespace transformer_engine { +namespace curanddx { +namespace detail { + +inline constexpr unsigned int philox4x32_w32_0 = 0x9E3779B9U; +inline constexpr unsigned int philox4x32_w32_1 = 0xBB67AE85U; +inline constexpr unsigned int philox4x32_m4x32_0 = 0xD2511F53U; +inline constexpr unsigned int philox4x32_m4x32_1 = 0xCD9E8D57U; + +__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int* hip) { + *hip = __umulhi(a, b); + return a * b; +} + +__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(philox4x32_m4x32_0, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(philox4x32_m4x32_1, ctr.z, &hi1); + + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; +} + +template +__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) { + for (unsigned int i = 0; i < Rounds - 1; i++) { + c = single_round(c, k); // 1 + k.x += philox4x32_w32_0; + k.y += philox4x32_w32_1; + } + return single_round(c, k); // Rounds +} + +template +struct philox4x32_native_state { + static constexpr unsigned int rounds = Rounds; + + uint4 ctr; + uint2 key; + + __forceinline__ __device__ void philox_state_incr() { + if (++ctr.x) return; + if (++ctr.y) return; + if (++ctr.z) return; + ++ctr.w; + } + + __forceinline__ __device__ void philox_state_incr(size_t n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + + ctr.x += nlo; + if (ctr.x < nlo) nhi++; + + ctr.y += nhi; + if (nhi <= ctr.y) return; + if (++ctr.z) return; + ++ctr.w; + } + + __forceinline__ __device__ void philox_state_incr_hi(size_t n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + + ctr.z += nlo; + if (ctr.z < nlo) nhi++; + + ctr.w += nhi; + } + + // offset is the total # of 128bits generated with a single generate4() call + __forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); } + + __forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); } + + __forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) { + ctr = make_uint4(0, 0, 0, 0); + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + + skip_subsequence(subsequence); + skip_offset(offset); + } + + __forceinline__ __device__ uint4 generate4() { + auto tmp = multiple_rounds(ctr, key); + philox_state_incr(); + return tmp; + } +}; +} // namespace detail +} // namespace curanddx +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh index 45fa29f0e..629520aeb 100644 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ b/transformer_engine/common/util/nvfp4_transpose.cuh @@ -32,9 +32,6 @@ namespace transformer_engine { #if FP4_TYPE_SUPPORTED namespace nvfp4_transpose { -using RNG = decltype(curanddx::Generator() + curanddx::PhiloxRounds<10>() + - curanddx::SM<800>() + curanddx::Thread()); - using namespace ptx; using nvfp4_scale_t = fp8e4m3; @@ -139,12 +136,15 @@ __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const return global_encode_scale; } -__device__ __forceinline__ uint32_t get_rbits(RNG &rng, uint4 &random_uint4, int &rnd_idx) { +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> + &rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4 &random_uint4, int &rnd_idx) { if (rnd_idx == 4) { rnd_idx = 0; - curanddx::uniform_bits dist; - random_uint4 = dist.generate4(rng); + random_uint4 = rng.generate4(); } + // Treat uint4 as an array of 4x uint32_t elements for indexing const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); const uint32_t rbits = rbits_arr[rnd_idx++]; @@ -363,9 +363,11 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; - uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x @@ -874,9 +876,11 @@ __global__ void __launch_bounds__(THREADS_NUM) threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - RNG rng(rng_seed, rng_sequence, rng_offset); - curanddx::uniform_bits dist; - uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? dist.generate4(rng) : uint4{0, 0, 0, 0}; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x From a019c80a403fc0f596210237ccf6b954e392c260 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 27 Oct 2025 18:27:52 -0400 Subject: [PATCH 087/141] Submodule checkout during setup (#2293) * Add checks to submodule during setup and automatically checkout Signed-off-by: Kirthi Shankar Sivamani * fix import and formatting Signed-off-by: Kirthi Shankar Sivamani * provide envvar to skip submodule init in setup Signed-off-by: Kirthi Shankar Sivamani * Fix formatting Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/setup.py b/setup.py index a820265c3..ce3805d2e 100644 --- a/setup.py +++ b/setup.py @@ -6,6 +6,8 @@ from importlib import metadata import os +import shutil +import subprocess import time from pathlib import Path from typing import List, Tuple @@ -126,9 +128,64 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def git_check_submodules() -> None: + """ + Attempt to checkout git submodules automatically during setup. + + This runs successfully only if the submodules are + either in the correct or uninitialized state. + + Note to devs: With this, any updates to the submodules itself, e.g. moving to a newer + commit, must be commited before build. This also ensures that stale submodules aren't + being silently used by developers. + """ + + # Provide an option to skip these checks for development. + if bool(int(os.getenv("NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD", "0"))): + return + + # Require git executable. + if shutil.which("git") is None: + return + + # Require a .gitmodules file. + if not (current_file_path / ".gitmodules").exists(): + return + + try: + submodules = subprocess.check_output( + ["git", "submodule", "status", "--recursive"], + cwd=str(current_file_path), + text=True, + ).splitlines() + + for submodule in submodules: + # '-' start is for an uninitialized submodule. + # ' ' start is for a submodule on the correct commit. + assert submodule[0] in ( + " ", + "-", + ), ( + "Submodules are initialized incorrectly. If this is intended, set the " + "environment variable `NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD` to a " + "non-zero value to skip these checks during development. Otherwise, " + "run `git submodule update --init --recursive` to checkout the correct" + " submodule commits." + ) + + subprocess.check_call( + ["git", "submodule", "update", "--init", "--recursive"], + cwd=str(current_file_path), + ) + except subprocess.CalledProcessError: + return + + if __name__ == "__main__": __version__ = te_version() + git_check_submodules() + with open("README.rst", encoding="utf-8") as f: long_description = f.read() From 4cf2f12b408540513f857eb6f516eebef757979b Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 27 Oct 2025 16:49:28 -0700 Subject: [PATCH 088/141] Change the pyTorch installation to CUDA 13 in Build All GitHub action (#2308) Change the pyTorch installation to CUDA 13 in Build All GitHub action to match the version in the JAX container Signed-off-by: Przemek Tredak --- .github/workflows/build.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f40b28189..42c5f0342 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,7 +83,9 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript + run: | + pip install pybind11[global] einops onnxscript + pip install torch --index-url https://download.pytorch.org/whl/cu130 - name: 'Checkout' uses: actions/checkout@v3 with: From a8e4346ec6630bf808d8904a4fdb19cb6f54b48b Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 28 Oct 2025 09:19:50 -0400 Subject: [PATCH 089/141] [JAX] Use TE quantization when TE fused norm is disable (#2303) * jax norm + te quant Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/normalization.py | 38 +++++++++++++------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 90ab5fb7f..d09ce7ef7 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -27,7 +27,7 @@ NamedSharding, get_cudnn_version, ) -from .quantization import _quantize_dbias_impl, AmaxScope +from .quantization import quantize, AmaxScope from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp_tpsp, @@ -945,7 +945,7 @@ def layernorm_fwd( beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float, - quantizer: Optional[Quantizer], + quantizer: Optional[Quantizer] = None, amax_scope: AmaxScope = AmaxScope.LOCAL, transpose_batch_sequence: bool = False, output_amax_when_no_scaling: bool = False, @@ -975,7 +975,16 @@ def layernorm_fwd( - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1) """ if not NormFwdPrimitive.enabled(): - return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) + output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon) + if quantizer is not None: + output = quantize( + output, + quantizer, + flatten_axis=-1, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + return (output, mu, rsigma) # TE/common does not support normalization with colwise only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: @@ -1029,7 +1038,7 @@ def layernorm_fwd( transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=False, ) - out, _ = _quantize_dbias_impl( + out, _ = quantize( out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence ) return out, mu, rsigma @@ -1050,11 +1059,9 @@ def layernorm_fwd( transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=True, ) - out, _ = _quantize_dbias_impl( + out = quantize( out, - is_dbias=False, quantizer=quantizer, - dq_dtype=x.dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, ) @@ -1219,7 +1226,16 @@ def rmsnorm_fwd( Shape: (..., 1) """ if not NormFwdPrimitive.enabled(): - return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) + output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon) + if quantizer is not None: + output = quantize( + output, + quantizer, + flatten_axis=-1, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + return (output, rsigma) # TE/common does not support normalization with colwise only quantization yet if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: @@ -1274,7 +1290,7 @@ def rmsnorm_fwd( transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=False, ) - out, _ = _quantize_dbias_impl( + out = quantize( out.data, quantizer, amax_scope=amax_scope, @@ -1297,11 +1313,9 @@ def rmsnorm_fwd( transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=True, ) - out, _ = _quantize_dbias_impl( + out = quantize( out, - is_dbias=False, quantizer=quantizer, - dq_dtype=x.dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, ) From c6cbcc85368adc03b23d4e55d1f258b4de19316a Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Wed, 29 Oct 2025 13:21:33 -0700 Subject: [PATCH 090/141] [Pytorch] Integrate GPT OSS Swiglu in TransformerLayer (#2312) * changes working Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add support for onnx, minor comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * greptile review comments Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/transformer.py Co-authored-by: Przemyslaw Tredak Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/module/layernorm_mlp.py Co-authored-by: Przemyslaw Tredak Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/transformer.py Co-authored-by: Przemyslaw Tredak Signed-off-by: vthumbe1503 * address review comments Signed-off-by: Varun Thumbe * address review comments Signed-off-by: Varun Thumbe * revert the name change Signed-off-by: Varun Thumbe --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak --- tests/pytorch/test_onnx_export.py | 2 +- tests/pytorch/test_sanity.py | 4 +- .../pytorch/module/layernorm_mlp.py | 56 +++++++++++++++---- transformer_engine/pytorch/transformer.py | 9 ++- 4 files changed, 57 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index f8b4d7481..2ce6eb82b 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -68,7 +68,7 @@ fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(None) -supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] +supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"] all_normalizations = ["LayerNorm", "RMSNorm"] diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e283842ec..f12e80d4c 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -122,6 +122,7 @@ def nvfp4_vanilla(): "sreglu", "silu", "swiglu", + "clamped_swiglu", ] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -547,7 +548,7 @@ def test_sanity_layernorm_mlp( sigma = 0.023 init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - + activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702} block = LayerNormMLP( config.hidden_size, 4 * config.hidden_size, @@ -555,6 +556,7 @@ def test_sanity_layernorm_mlp( output_layer_init_method=output_layer_init_method, zero_centered_gamma=zero_centered_gamma, activation=activation, + activation_params=activation_params, normalization=normalization, params_dtype=dtype, device="cuda", diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ccf5dc095..889f545c1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -99,6 +99,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "sreglu": (tex.sreglu, tex.dsreglu, None), "silu": (tex.silu, tex.dsilu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), } if recipe.delayed() or recipe.mxfp8(): # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] @@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "sreglu": (tex.sreglu, tex.dsreglu, None), "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu), "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: [] @@ -135,6 +137,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "sreglu": (tex.sreglu, tex.dsreglu, None), "silu": (tex.silu, tex.dsilu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), } raise NotImplementedError(f"Unhandled recipe type {recipe}") @@ -199,6 +202,7 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, activation: str, + activation_params: Optional[dict], normalization: str, ub_overlap_ag: bool, ub_overlap_rs: bool, @@ -440,6 +444,7 @@ def forward( # ACTIVATION - sometimes activation is fused with the GEMM above. fc1_out_without_bias = None + act_params = activation_params or {} if bias_gelu_fusion: fc1_out = None @@ -449,7 +454,7 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs elif debug: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, None) + act_out = activation_func(fc1_out, None, **act_params) act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs @@ -457,19 +462,19 @@ def forward( recipe = FP8GlobalStateManager.get_fp8_recipe() if recipe.float8_block_scaling(): # tex.quantize does not support GELU fusion for blockwise - act_out = activation_func(fc1_out, None) + act_out = activation_func(fc1_out, None, **act_params) act_out = tex.quantize(act_out, fc2_input_quantizer) elif recipe.custom(): # tex.quantize does not support custom quantizers - act_out = activation_func(fc1_out, None) + act_out = activation_func(fc1_out, None, **act_params) act_out = fc2_input_quantizer(act_out) else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) else: if fp8_calibration: - act_out = activation_func(fc1_out, None) + act_out = activation_func(fc1_out, None, **act_params) else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -624,6 +629,7 @@ def forward( ctx.device = device ctx.activation_dtype = activation_dtype ctx.activation = activation + ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -1002,6 +1008,7 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # bias computation + act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False if ctx.fc1_grad_output_quantizer is not None: @@ -1015,7 +1022,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ctx.debug: dact_func = _act_func(ctx.activation)[1] - dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None) + dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) elif ( @@ -1027,7 +1034,10 @@ def fc2_wgrad_gemm( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer + fc2_dgrad, + fc1_out.to(ctx.activation_dtype), + ctx.fc1_grad_output_quantizer, + **act_params, ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, @@ -1036,7 +1046,7 @@ def fc2_wgrad_gemm( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[1] dact = activation_func_bwd( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), None + fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision if ctx.fp8: @@ -1401,6 +1411,7 @@ def fc1_wgrad_gemm( None, # bwd_ln_sm_margin None, # zero_centered_gamma None, # activation + None, # activation_params None, # normalization None, # ub_overlap_ag None, # ub_overlap_rs @@ -1436,7 +1447,11 @@ class LayerNormMLP(TransformerEngineBaseModule): activation : str, default = 'gelu' activation function used. Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', and 'swiglu'. + 'silu', 'swiglu', and 'clamped_swiglu'. + activation_params : dict, default = `None` + Additional parameters for the activation function. + At the moment, only used for 'clamped_swiglu' activation which + supports 'limit' and 'alpha' parameters. init_method : Callable, default = `None` used for initializing FC1 weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. @@ -1537,6 +1552,7 @@ def __init__( bias: bool = True, normalization: str = "LayerNorm", activation: str = "gelu", + activation_params: Optional[dict] = None, output_layer_init_method: Optional[Callable] = None, fuse_wgrad_accumulation: bool = False, params_dtype: Optional[torch.dtype] = None, @@ -1564,6 +1580,7 @@ def __init__( assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" self.use_bias = bias self.activation = activation + self.activation_params = activation_params self.return_bias = return_bias self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output @@ -1643,7 +1660,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]: + if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1897,6 +1914,7 @@ def forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.activation, + self.activation_params, self.normalization, self.ub_overlap_ag, self.ub_overlap_rs, @@ -2026,6 +2044,19 @@ def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Ten fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias) fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32 + act_params = self.activation_params or {} + # Default params for clamped_swiglu in Transformer Engine + clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get( + "alpha", 1.702 + ) + + def _clamped_swiglu(x, limit, alpha): + x_glu, x_linear = x.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + y = out_glu * (x_linear + 1) + return y activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), @@ -2040,6 +2071,9 @@ def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Ten * x.chunk(2, -1)[1], "silu": torch.nn.functional.silu, "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "clamped_swiglu": lambda x: _clamped_swiglu( + x, clamped_swiglu_limit, clamped_swiglu_alpha + ), } if self.activation not in activation_map: raise ValueError(f"Unsupported activation in onnx export: {self.activation}") diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 8a032b2f5..4c7599ad8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -176,7 +176,12 @@ class TransformerLayer(torch.nn.Module): activation : str, default = 'gelu' Type of activation used in MLP block. Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', and 'swiglu'. + 'silu', 'swiglu', and 'clamped_swiglu'. + activation_params : Optional[dict], default = `None` + Additional parameters for the activation function. + At the moment, only used for 'clamped_swiglu' activation which + supports 'limit' and 'alpha' parameters. You can set these as + `activation_params={'limit': 7.0, 'alpha': 1.702}`. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the @@ -310,6 +315,7 @@ def __init__( ub_bulk_wgrad: bool = True, bias: bool = True, activation: str = "gelu", + activation_params: Optional[dict] = None, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", @@ -475,6 +481,7 @@ def __init__( ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, activation=activation, + activation_params=activation_params, normalization=normalization, device=device, name=name + ".layernorm_mlp" if name is not None else None, From f0295f9d9b0b6353dac0accd36a8030b55dbd733 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 30 Oct 2025 13:12:03 -0400 Subject: [PATCH 091/141] CMake to respect MAX_JOBS or NVTE_MAX_JOBS (#2319) fix max jobs Signed-off-by: Phuong Nguyen --- transformer_engine/common/CMakeLists.txt | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e388dd794..e0c42b2d9 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -352,10 +352,10 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") # Number of parallel build jobs -if(ENV{MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") -elseif(ENV{NVTE_BUILD_MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") +if($ENV{MAX_JOBS}) + set(BUILD_JOBS_STR $ENV{MAX_JOBS}) +elseif($ENV{NVTE_BUILD_MAX_JOBS}) + set(BUILD_JOBS_STR $ENV{NVTE_BUILD_MAX_JOBS}) else() set(BUILD_JOBS_STR "max") endif() From 5e8a9a961f5375cd7c098989f618194bbcf4e9cb Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 30 Oct 2025 10:27:42 -0700 Subject: [PATCH 092/141] [JAX] Fix: Skip determinism tests for bprop for all sm >=100 (#2315) * Fix: Skip determinism tests for bprop for all sm >=100 Signed-off-by: Kshitij Lakhani * Add username to TODO Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100+ Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 6 +++--- transformer_engine/jax/cpp_extensions/attention.py | 7 +++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 5b814cb99..a5d73d960 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -378,14 +378,14 @@ def _check_configs(self): pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - + # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support if ( - get_device_compute_capability(0) == 100 + get_device_compute_capability(0) >= 100 and self.dropout_prob == 0.1 and self.attn_bias_type is not AttnBiasType.NO_BIAS ): pytest.skip( - "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index db2537c38..c0cb6cda1 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -2739,10 +2739,13 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - if 100 in get_all_device_compute_capability(): + # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on + # sm100+ + compute_capabilities = get_all_device_compute_capability() + if any(x >= 100 for x in compute_capabilities): assert not ( attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 - ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, From 490a5f41ada5788bc6dd94ba54ab024e465e0ec6 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 30 Oct 2025 15:32:37 -0400 Subject: [PATCH 093/141] [PyTorch] Fix attention backend and tests for `sm120` (#2320) * Fix attention backend and tests for sm120 Signed-off-by: Kirthi Shankar Sivamani * Disable MLA only for backward Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/attention/test_attention.py | 22 +++++++----- .../attention/dot_product_attention/utils.py | 35 +++++++++++++++++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index b05a0447c..a671f1eec 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -61,8 +61,16 @@ get_available_attention_backends, ) -# Check if hardware supports FP8 +# Check if hardware supports FP8 attention. fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8 +device_compute_capability = get_device_compute_capability() +if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)): + fp8_attn_available = False + reason_for_no_fp8_attn = ( + "FP8 attention is not supported for compute capability =" + f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" + ) # Reset RNG seed and states seed = 1234 @@ -1573,8 +1581,7 @@ def _run_transformer_layer( } -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @@ -1736,8 +1743,7 @@ def get_model(dtype, config): @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @@ -1973,8 +1979,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @@ -2302,8 +2307,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ), reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""", ) -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) def test_custom_mha_fp8_vs_f16(dtype, model): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 51279bd37..25dc0e96c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -481,6 +481,20 @@ def get_attention_backend( logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0") use_fused_attention = False + if device_compute_capability == (12, 0): + if use_flash_attention: + logger.debug( + "Disabling FlashAttention as FP8 is not supported" + " for compute capability = sm120" + ) + if use_fused_attention: + logger.debug( + "Disabling FusedAttention as FP8 is not supported" + " for compute capability = sm120" + ) + use_flash_attention = False + use_fused_attention = False + # Filter: Return max_logit if return_max_logit: if use_flash_attention: @@ -560,6 +574,20 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False + if ( + device_compute_capability == (12, 0) + and (head_dim_qk > 128 or head_dim_qk % 8 != 0) + and is_training + ): + if use_fused_attention: + logger.debug( + "Disabling FusedAttention as MLA for backward pass is not supported for compute" + " capability = sm120 for a head_dim_qk > 128 or head_dim_qk %%8 != 0. Found:" + " head_dim_qk = %s", + head_dim_qk, + ) + use_fused_attention = False + if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 @@ -629,6 +657,13 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False + if device_compute_capability == (12, 0): + if use_fused_attention: + logger.debug( + "Disabling FusedAttention as qkv_format = thd is" + " not supported for compute capability = sm120" + ) + use_fused_attention = False # Filter: Dropout if attention_dropout != 0.0 and use_flash_attention_3: From 0e80c847845be58dbb88f46a0975786ddc823798 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Thu, 30 Oct 2025 20:53:11 +0100 Subject: [PATCH 094/141] [Common] Split cast/gated kernels by scaling mode (#2248) * Separated gated and dequantize kernels Signed-off-by: Oleg Goncharov * Separated quantize, dequantize and gated functions Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed lint issues Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed persistent lint issues Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added missing compute capability 10.0 check for Quantize FP8 TMA kernels Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed the issue which was added again by autofix Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed files description. Completely removed non-identity activations from the NVFP4 transpose test suite Signed-off-by: Oleg Goncharov * Removed unsupported template arguments in NVFP4 quantize Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed undefined symbol error Signed-off-by: Oleg Goncharov * Fixed condition Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> * Fixed CUDA version check Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed arch conditions order Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix Signed-off-by: Oleg Goncharov * Clean up Signed-off-by: Oleg Goncharov * Small fix Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Small fix Signed-off-by: Oleg Goncharov * Fixes per the PR review Signed-off-by: Oleg Goncharov * Fix Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Split quantize helper into two (FWD and BWD) functions Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Moved activation functions from cast.cu. Removed cast.cu from the fast-math compilation list Signed-off-by: Oleg Goncharov * Enabled fast math for activations by default Signed-off-by: Oleg Goncharov * Disabled fast math for activations by default Signed-off-by: Oleg Goncharov --------- Signed-off-by: Oleg Goncharov Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/CMakeLists.txt | 5 +- .../common/activation/activation_template.h | 27 +- transformer_engine/common/activation/gelu.cu | 26 + transformer_engine/common/activation/relu.cu | 26 + .../common/activation/swiglu.cu | 13 + transformer_engine/common/cast/cast.cu | 102 + .../common/cast/core/common.cuh | 97 + .../common/cast/dispatch/dequantize.cuh | 56 + .../common/cast/dispatch/gated.cuh | 161 ++ .../common/cast/dispatch/quantize.cuh | 326 +++ .../common/cast/fp8/dequantize_fp8.cuh | 54 + .../common/cast/fp8/gated_fp8.cuh | 394 +++ .../common/cast/fp8/quantize_fp8.cuh | 580 +++++ .../mxfp8/dequantize_mxfp8.cuh} | 171 +- .../mxfp8/gated_mxfp8.cuh} | 711 +----- .../common/cast/mxfp8/quantize_mxfp8.cuh | 722 ++++++ .../common/cast/nvfp4/core_nvfp4.cuh | 112 + .../common/cast/nvfp4/dequantize_nvfp4.cuh | 111 + .../common/cast/nvfp4/quantize_nvfp4.cuh | 688 ++++++ .../cast/nvfp4/quantize_transpose_nvfp4.cuh | 1287 ++++++++++ transformer_engine/common/util/cast.cu | 201 -- .../common/util/cast_kernels.cuh | 2188 ----------------- transformer_engine/common/util/math.h | 2 + transformer_engine/common/util/ptx.cuh | 191 +- 24 files changed, 5073 insertions(+), 3178 deletions(-) create mode 100644 transformer_engine/common/cast/cast.cu create mode 100644 transformer_engine/common/cast/core/common.cuh create mode 100644 transformer_engine/common/cast/dispatch/dequantize.cuh create mode 100644 transformer_engine/common/cast/dispatch/gated.cuh create mode 100644 transformer_engine/common/cast/dispatch/quantize.cuh create mode 100644 transformer_engine/common/cast/fp8/dequantize_fp8.cuh create mode 100644 transformer_engine/common/cast/fp8/gated_fp8.cuh create mode 100644 transformer_engine/common/cast/fp8/quantize_fp8.cuh rename transformer_engine/common/{util/dequantize_kernels.cuh => cast/mxfp8/dequantize_mxfp8.cuh} (69%) rename transformer_engine/common/{util/cast_gated_kernels.cuh => cast/mxfp8/gated_mxfp8.cuh} (53%) create mode 100644 transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh create mode 100644 transformer_engine/common/cast/nvfp4/core_nvfp4.cuh create mode 100644 transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh create mode 100644 transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh delete mode 100644 transformer_engine/common/util/cast.cu delete mode 100644 transformer_engine/common/util/cast_kernels.cuh diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index e0c42b2d9..62b769c77 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -168,7 +168,7 @@ list(APPEND transformer_engine_cuda_sources list(APPEND transformer_engine_cuda_arch_specific_sources gemm/cutlass_grouped_gemm.cu - util/cast.cu + cast/cast.cu activation/gelu.cu activation/relu.cu activation/swiglu.cu @@ -336,8 +336,7 @@ option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --u if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu activation/relu.cu - activation/swiglu.cu - util/cast.cu) + activation/swiglu.cu) endif() foreach(cuda_source IN LISTS nvte_sources_with_fast_math) diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 1d9a3fb43..7353c3e1d 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -14,26 +14,17 @@ #include #include +#include "../cast/dispatch/gated.cuh" +#include "../cast/dispatch/quantize.cuh" #include "../common.h" -#include "../util/cast_gated_kernels.cuh" -#include "../util/cast_kernels.cuh" -#include "../util/math.h" -#include "../util/vectorized_pointwise.h" namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; constexpr bool IS_ACT = true; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template @@ -42,20 +33,17 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, using namespace detail; constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + nullptr, stream); } template void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DGATED = false; - constexpr NVTETensor grad = nullptr; - quantize_gated_helper(grad, input, output, p, stream); + dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template (grad, input, output, p, stream); + dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 4949ba590..4979023ef 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -20,6 +20,19 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; @@ -48,6 +61,19 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index c74fc6eee..c0ef9fd65 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -20,6 +20,19 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; @@ -48,6 +61,19 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index cafc48abb..6957a91e6 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -20,6 +20,19 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu new file mode 100644 index 000000000..1ed46a335 --- /dev/null +++ b/transformer_engine/common/cast/cast.cu @@ -0,0 +1,102 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/multi_stream.h" +#include "../utils.cuh" +#include "dispatch/dequantize.cuh" +#include "dispatch/quantize.cuh" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_fwd_helper(input, output, nullptr, stream); +} + +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; + + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_fwd_helper(input, output, quant_config, stream); +} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTETensor activation_input = nullptr; + + dispatch::quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); + using namespace transformer_engine; + dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), + stream); +} + +void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, + const NVTEQuantizationConfig quant_configs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + + const size_t num_streams = nvte_get_num_compute_streams(); + + int num_stream_used = std::min(num_streams, num_tensors); + // wait for current stream to finish + NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); + } + + for (int i = 0; i < num_tensors; i++) { + dispatch::quantize_fwd_helper( + inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); + } + + // record events on compute streams + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); + } + // wait for all compute streams to finish + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); + } +} diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh new file mode 100644 index 000000000..b750142f5 --- /dev/null +++ b/transformer_engine/common/cast/core/common.cuh @@ -0,0 +1,97 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file common.cuh + * \brief Common functions in quantize. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace common { +inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { + const size_t N = product(t->data.shape); + const bool isFullTile = (N % elems_per_block == 0); + return isFullTile; +} + +inline bool dimensions_supported_by_TMA(const Tensor *const t) { + const size_t cols = t->flat_last_dim(); + constexpr size_t TMA_bytes = 16; + const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); + return cols % alignment_requirement == 0; +} + +namespace kernel { + +constexpr size_t THREADS_PER_BLOCK = 256; +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, + const size_t rows, const size_t cols) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + thread_id * nvec; + OType *const thread_out_base = dbias_output + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} +} // namespace kernel + +template +void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); + const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec); + + reduce_dbias_kernel + <<>>( + reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace common +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh new file mode 100644 index 000000000..b8547915c --- /dev/null +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -0,0 +1,56 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize.cuh + * \brief Dequantize dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../fp8/dequantize_fp8.cuh" +#include "../mxfp8/dequantize_mxfp8.cuh" +#include "../nvfp4/dequantize_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { + +inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { + if (is_supported_by_CC_100()) { + mxfp8::dequantize(input, output, stream); + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } + break; + } + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh new file mode 100644 index 000000000..4373090b7 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file gated.cuh + * \brief Gated dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ + +#include + +#include "../../common.h" +#include "../../utils.cuh" +#include "../fp8/gated_fp8.cuh" +#include "../mxfp8/gated_mxfp8.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p, + cudaStream_t stream) { + const Tensor input = *convertNVTETensorCheck(nvte_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim() / 2; + + NVTE_CHECK(input.flat_last_dim() % 2 == 0, + "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", + input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols, + "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + Tensor dummy_grad_tensor; + fp8::cast_gated_tma(input, dummy_grad_tensor, + output, p, stream); + } else { + fp8::cast_gated_fwd(input, output, p, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + Tensor dummy_grad_tensor; + mxfp8::quantize_gated(input, dummy_grad_tensor, + output, p, stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } +} + +template +void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input, + NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) { + const Tensor &grad = *(convertNVTETensorCheck(nvte_grad)); + const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input); + Tensor *output = convertNVTETensorCheck(nvte_output); + + CheckInputTensor(grad, "grad"); + CheckInputTensor(gated_input, "gated_input"); + CheckOutputTensor(*output, "output", /*allow_empty=*/false); + + NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ", + gated_input.flat_last_dim(), "."); + + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + + NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); + NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); + + NVTE_CHECK(grad.flat_first_dim() == rows, + "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + NVTE_CHECK(grad.flat_last_dim() == cols, + "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [", + grad.flat_first_dim(), ", ", grad.flat_last_dim(), "]."); + + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [", + rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(output->flat_last_dim() == cols * 2, + "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [", + output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); + NVTE_CHECK(gated_input.data.shape == output->data.shape, + "Gated input and output shapes must match. Input shape: ", gated_input.data.shape, + ", output shape: ", output->data.shape, "."); + + switch (output->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); + if (use_tma_kernels) { + fp8::cast_gated_tma(gated_input, grad, output, p, + stream); + } else { + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + NVTE_CHECK(cols % 32 == 0, + "Invalid input shape. Expected the last dimension to be " + "divisible by 32, but got ", + cols, "."); + if (output->has_data()) { + NVTE_CHECK(is_fp8_dtype(output->data.dtype), + "The type of the output tensor should be FP8."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), + "The type of the columnwise output tensor should be FP8."); + } + NVTE_CHECK(is_supported_by_CC_100(), + "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); + + mxfp8::quantize_gated(gated_input, grad, output, p, + stream); + break; + } + default: + NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + "."); + } +} +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_ diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh new file mode 100644 index 000000000..9f7a4a9b0 --- /dev/null +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -0,0 +1,326 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize.cuh + * \brief Quantize dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/vectorized_pointwise.h" +#include "../core/common.cuh" +#include "../fp8/quantize_fp8.cuh" +#include "../mxfp8/quantize_mxfp8.cuh" +#include "../nvfp4/quantize_nvfp4.cuh" +#include "../nvfp4/quantize_transpose_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { + +template +void quantize_fwd_helper(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *input_tensor = convertNVTETensorCheck(input); + Tensor *output_tensor = convertNVTETensorCheck(output); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_ACT) { + cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + const Tensor *dummy_input_tensor = nullptr; + Tensor *dummy_dbias_tensor = nullptr; + Tensor *dummy_workspace_tensor = nullptr; + mxfp8::quantize( + *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor, + dummy_workspace_tensor, stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*input_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = input_tensor->flat_first_dim(); + int32_t cols = input_tensor->flat_last_dim(); + auto dtype = input_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/input_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_ACT, ParamOP, OP parameters support. + NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + columnwise_option = columnwise_compact + ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +template +void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output, + NVTETensor dbias, NVTETensor workspace, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + using namespace detail; + + const Tensor *grad_tensor = convertNVTETensorCheck(grad); + const Tensor *input_tensor = convertNVTETensor(input); + + Tensor *output_tensor = convertNVTETensorCheck(output); + Tensor *dbias_tensor = convertNVTETensor(dbias); + Tensor *workspace_tensor = convertNVTETensor(workspace); + + // Quantization config + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + + // Noop flag + Tensor dummy_tensor; + Tensor *noop_tensor = &dummy_tensor; + if (quant_config_cpp.noop_tensor != nullptr) { + noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); + } + + // Check for unsupported options + if (quant_config_cpp.stochastic_rounding) { + NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Stochastic rounding is only supported for NVFP4 quantization."); + } + + NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(), + "Either rowwise or columnwise output data need to be allocated."); + + // Dispatch to quantization kernel depending on data format + switch (output_tensor->scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + if (output_tensor->has_columnwise_data()) { + NVTE_CHECK(output_tensor->has_data(), + "Quantizing in only the columnwise direction not supported yet!"); + if constexpr (!IS_DBIAS && !IS_DACT) { + cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream); + } else { + cast_transpose_fused( + *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream); + } + } else if (output_tensor->has_data()) { + fp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + } + break; + } + case NVTE_MXFP8_1D_SCALING: { + mxfp8::quantize( + *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor, + stream); + break; + } + case NVTE_NVFP4_1D_SCALING: { + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); + + // Check tensors + CheckNoopTensor(*noop_tensor, "cast_noop"); + CheckInputTensor(*grad_tensor, "input"); + CheckOutputTensor(*output_tensor, "output", false); + + // Choose kernel + int32_t rows = grad_tensor->flat_first_dim(); + int32_t cols = grad_tensor->flat_last_dim(); + auto dtype = grad_tensor->dtype(); + bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && + (cols % 32 == 0) && output_tensor->has_data(); + + // Launch NVFP4 quantize kernel + if (use_optimized_kernel) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_transpose( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else { + auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax + : output_tensor->columnwise_amax; + quantize_transpose_vector_blockwise_fp4( + /*input=*/grad_tensor->data, /*global_amax=*/global_amax, + /*scale_inv=*/output_tensor->scale_inv, + /*scale_inv_t=*/output_tensor->columnwise_scale_inv, + /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, + /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), + /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, + /*swizzled_scale=*/false, + /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, + /*rng_state=*/quant_config_cpp.rng_state, + /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + } + break; + } + case NVTE_BLOCK_SCALING_2D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + quantize_transpose_square_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor->data, stream); + break; + } + case NVTE_BLOCK_SCALING_1D: { + // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support. + NVTE_CHECK((!IS_DBIAS && !IS_DACT), + "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D"); + bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; + float epsilon = quant_config_cpp.amax_epsilon; + FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; + FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; + if (output_tensor->has_data()) { + bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT + : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; + } + if (output_tensor->has_columnwise_data()) { + bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == + Float8BlockScaleTensorFormat::COMPACT); + columnwise_option = columnwise_compact + ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT + : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; + } + quantize_transpose_vector_blockwise( + grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor->data, stream); + break; + } + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh new file mode 100644 index 000000000..2514758b5 --- /dev/null +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_fp8.cuh + * \brief CUDA kernels to dequantize from FP8. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + DequantizeParam p; p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh new file mode 100644 index 000000000..225ef93ed --- /dev/null +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -0,0 +1,394 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file gated_fp8.cuh + * \brief CUDA kernels to cast to FP8 with gated activations. + */ + +#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_FP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +namespace kernel { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t grad_mem = IS_BWD ? buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_BWD) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const size_t buff = it % BUFFERS_NUM; + const size_t next_it = it + 1; + if (next_it < ITERATIONS) { + const size_t next_buff = next_it % BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_BWD) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + } + + if constexpr (IS_BWD) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); + act_x = x * s; + if (act_elt <= p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; + } else { + dact_x = 0.0f; + } + } else { + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } + } + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, p) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_BWD) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. + // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace kernel + +template +void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, + cudaStream_t stream) { + using namespace kernel; + checkCuDriverContext(stream); + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_BWD ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_BWD) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, typeToNumBits(gated_input.dtype())); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, + typeToNumBits(output->dtype())); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; + + auto kernel = cast_fp8_gated_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), + output->flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh new file mode 100644 index 000000000..efc5015b7 --- /dev/null +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -0,0 +1,580 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_fp8.cuh + * \brief CUDA kernels to quantize to FP8. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +namespace quantize_2D_kernel { + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + const size_t dbias_offset_Y = blockIdx.y + tid_Y; + const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const size_t dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const size_t chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const size_t buff = iter % FP8_BUFFERS_NUM; + const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const size_t next_buff = next_iter % FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const size_t dbias_offset_X = my_column; + const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_2D_kernel + +namespace quantize_1D_kernel { +using namespace quantize_2D_kernel; + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const size_t buff = iter % SHMEM_BUFFERS; + const size_t it_offset = iter * SHMEM_DIM; + + const size_t next_iter = iter + 1; + const size_t next_buff = next_iter % SHMEM_BUFFERS; + const size_t next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_1D_kernel + +template +void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace quantize_1D_kernel; + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + using namespace quantize_2D_kernel; + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + NVTE_CHECK_CUDA(cudaGetLastError()); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +namespace detail { +using Empty = transformer_engine::Empty; +__device__ inline float identity(float value, const Empty &) { return value; } +} // namespace detail + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_1D_kernel; + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + if (!IS_DBIAS && !IS_DACT) { + if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 + quantize_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 (+dAct) + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } + } else { + if (IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? + NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + + " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } +} + +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh similarity index 69% rename from transformer_engine/common/util/dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 9f70ce4cd..fb43fce96 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -4,36 +4,27 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file dequantize_kernels.cuh - * \brief CUDA kernels to cast from MXFP8. +/*! \file dequantize_mxfp8.cuh + * \brief CUDA kernels to dequantize from MXFP8. */ -#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ #include #include #include -#include - -#include -#include -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transformer_engine.h" -#include "transformer_engine/transpose.h" +#include -namespace transformer_engine { +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" -namespace dequantization { +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace dequantize_kernel { constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -228,29 +219,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace dequantize_kernel -void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), nullptr, - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) -} - -void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace dequantize_kernel; bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); checkCuDriverContext(stream); @@ -334,113 +306,8 @@ void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); } - -#if CUDA_VERSION >= 12080 -template -__global__ void __launch_bounds__(512) - dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride) { - const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; - const size_t x = thread_idx % M; - const size_t y = thread_idx / M; - - union fp4vec { - uint64_t vec; - fp4e2m1x4 small_vec[4]; - }; - using OVec = Vec; - const uint64_t *const input_vectorized = reinterpret_cast(input); - OVec *output_vec = reinterpret_cast(output); - - const size_t my_index = x + y * M; - const size_t my_scale_index = x + y * scale_stride; - const size_t my_output_index = (x + y * M) * 4; - fp4vec value; - value.vec = input_vectorized[my_index]; - fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; - constexpr float factor_inv = 1.0 / (6.0 * 448.0); - float final_scale = static_cast(scale) * amax * factor_inv; -#pragma unroll - for (int i = 0; i < 4; i++) { - float4 current = static_cast(value.small_vec[i]); - OVec out; - out.data.elt[0] = static_cast(current.x * final_scale); - out.data.elt[1] = static_cast(current.y * final_scale); - out.data.elt[2] = static_cast(current.z * final_scale); - out.data.elt[3] = static_cast(current.w * final_scale); - output_vec[my_output_index + i] = out; - } -} -#endif // CUDA_VERSION - -void fp4_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { -#if CUDA_VERSION >= 12080 - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output"); - NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); - NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - constexpr int FP4_BLOCK_SIZE = 16; - const size_t N = input.flat_first_dim(); - const size_t M = input.flat_last_dim(); - - NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", - FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); - - const size_t Mread = M / FP4_BLOCK_SIZE; - const size_t total = N * Mread; - const size_t threads = 512; - const size_t blocks = DIVUP(total, threads); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, - input.scale_inv.shape.back());); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -#else - NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); -#endif // CUDA_VERSION >= 12080 -} - -} // namespace dequantization - -namespace detail { - -void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - switch (input.scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - dequantization::fp8_dequantize(input, output, stream); - break; - } - case NVTE_MXFP8_1D_SCALING: { - if (is_supported_by_CC_100()) { - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - break; - } - case NVTE_NVFP4_1D_SCALING: { - dequantization::fp4_dequantize(input, output, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } -} - -} // namespace detail - +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh similarity index 53% rename from transformer_engine/common/util/cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 93086bd82..4f0e1b80f 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -4,280 +4,27 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file cast_gated_kernels.cuh - * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. +/*! \file gated_mxfp8.cuh + * \brief CUDA kernels to cast to MXFP8 with gated activations. */ -#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ #include #include #include -#include -#include +#include -#include - -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" namespace transformer_engine { - -namespace gated_kernels { - -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 512; -constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 -constexpr size_t BUFFERS_NUM = 2; -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); - -__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act, - const __grid_constant__ CUtensorMap tensor_map_output_gate, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols, - const ParamOP p) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; - - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - - constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t in_transaction_size = buff_elems * sizeof(IType); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); - const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - // Prefetch data of the first stage - - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, - TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, - chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } else { - copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } - -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const size_t buff = it % BUFFERS_NUM; - const size_t next_it = it + 1; - if (next_it < ITERATIONS) { - const size_t next_buff = next_it % BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3( - &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, - &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, - in_transaction_size, &mbar[next_it], is_master_thread); - } else { - copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, - chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, - &mbar[next_it], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_sh_curr = out_act_sh + buff * buff_elems; - OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; -#pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - bool dgate_elt = true; // gating is ideally an identity function - if constexpr (std::is_same::value) { - // In case of GPT OSS, clamp the activation and gate values - dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp - gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; - } - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - - const float x = act_elt; - float act_x; - float dact_x; - if constexpr (std::is_same::value) { - const float x = min(act_elt, p.limit); - const float s = sigmoidf(p.alpha * x); - act_x = x * s; - if (act_elt <= p.limit) { - dact_x = s + s * (1 - s) * p.alpha * x; - } else { - dact_x = 0.0f; - } - } else { - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, p); - dact_x = DActOP(x, p); - } - } - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f; - - out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, p) * gate_elt; - out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // dGeLU - ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, - chunk_it_offset_y, - reinterpret_cast(out_act_sh_curr)); - - if constexpr (IS_DGATED) { - // dGate - ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_sh_curr)); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -namespace mxfp8_kernel { +namespace dispatch { +namespace mxfp8 { +namespace gated_kernel { constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; @@ -302,20 +49,21 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 -template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise, const ParamOP p) { + quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -385,14 +133,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -427,7 +175,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) int parity = 0; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, @@ -454,7 +202,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], @@ -497,7 +245,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; } - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; @@ -526,20 +274,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_elt = static_cast(static_cast(after_gate_elt)); } } after_act_colwise[i] = after_act_elt; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_colwise[i] = after_gate_elt; } // Cache computed activations to avoid computing them again in the 2nd pass along another dimension if constexpr (IS_CACHED_ACT_OP) { cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); } } @@ -549,7 +297,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (!out_of_bounds) { thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); } } @@ -578,7 +326,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // All threads read the reduced amax (ACT) thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { // Make sure the previous read of the ACT values has been completed, // so the data are not rewritten __syncthreads(); @@ -622,7 +370,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); float block_scale_inverse_gate; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); @@ -639,7 +387,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { const size_t shmem_offset_elt = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { OType2 out_pair; ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, @@ -685,7 +433,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Load cached elements in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); } // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) @@ -695,7 +443,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); } } @@ -705,7 +453,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], in_cached_act[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], in_cached_gate[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); @@ -717,7 +465,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (!std::is_same_v) { thread_amax_act = static_cast( __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = static_cast( __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); } @@ -735,7 +483,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) in_act.load_from(&in_act_sh[shmem_offset_rowwise]); in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); } @@ -753,7 +501,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; } - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; @@ -786,7 +534,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_elt = static_cast(static_cast(after_gate_elt)); } } @@ -796,7 +544,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); } } @@ -822,7 +570,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_scale_inverse_gate; ptx::floatx2 block_scale_inverse_2x_gate; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; @@ -853,7 +601,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { IType2 in_gate; OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); @@ -873,7 +621,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); } } @@ -894,7 +642,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); @@ -904,7 +652,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); @@ -920,94 +668,13 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -} // namespace mxfp8_kernel +} // namespace gated_kernel -template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, +void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, cudaStream_t stream) { - checkCuDriverContext(stream); - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act{}; - alignas(64) CUtensorMap tensor_map_output_gate{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { + using namespace gated_kernel; checkCuDriverContext(stream); const bool USE_ROWWISE_SCALING = output->has_data(); @@ -1031,17 +698,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + const size_t output_cols = (IS_BWD ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; - constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; @@ -1073,7 +734,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out constexpr size_t input_type_bit_size = TypeInfo::size; constexpr size_t output_type_bit_size = TypeInfo::size; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, input_type_bit_size); } @@ -1110,238 +771,68 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; switch (scaling_type) { - case ScalingType::ROWWISE: + case ScalingType::ROWWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::COLWISE: + } + case ScalingType::COLWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::BIDIMENSIONAL: + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise, p); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_gated(const Tensor &input, Tensor *output, ParamOP p, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, - "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), p, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP p, - cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, - "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match. Input shape: ", input.data.shape, - ", output shape: ", output->data.shape, "."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), p, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP p, - cudaStream_t stream) { - constexpr bool allow_empty = false; - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", allow_empty); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - if constexpr (IS_DGATED) { - CheckInputTensor(grad, "grad"); - NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); - NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); - } - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - bool is_fp8_rowwise_output = true; - bool is_fp8_colwise_output = true; - if (output->has_data()) { - is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - if (output->has_columnwise_data()) { - is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - - const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; - - if (is_delayed_tensor_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, p, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, p, stream); - } else { - cast_gated(gated_input, output, p, stream); - } - } - } else if (is_mxfp8_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, p, stream); - } else { - NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } - } else { - NVTE_ERROR("Not supported scaling mode"); - } -} -} // namespace gated_kernels - -namespace detail { - -template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - ParamOP p, cudaStream_t stream) { - using namespace gated_kernels; - Tensor grad_empty_tensor; - const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); - Tensor *output_tensor = convertNVTETensorCheck(output); - - if (is_supported_by_CC_100()) { - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, p, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, p, - stream); - } else { - cast_gated(gated_input_tensor, output_tensor, p, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } -} // namespace detail +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh new file mode 100644 index 000000000..5505de605 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -0,0 +1,722 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_mxfp8.cuh + * \brief CUDA kernels to quantize to MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace quantize_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_kernel; + checkCuDriverContext(stream); + + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::COLWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + } + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh new file mode 100644 index 000000000..cff846490 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -0,0 +1,112 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file core_nvfp4.cuh + * \brief Core functions used in NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ + +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../util/curanddx.hpp" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif // FP4_TYPE_SUPPORTED + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +using nvfp4_scale_t = fp8e4m3; + +namespace quantization_and_transposition_SF { +#if FP4_TYPE_SUPPORTED +// Used in transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. + using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantization_and_transposition_SF + +namespace quantization_SF { +#if FP4_TYPE_SUPPORTED +// Used in non-transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantization_SF + +namespace core { + +#if FP4_TYPE_SUPPORTED +using namespace ptx; + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> &rng, + // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + random_uint4 = rng.generate4(); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace core +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh new file mode 100644 index 000000000..bf7b535be --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -0,0 +1,111 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_nvfp4.cuh + * \brief CUDA kernels to dequantize from NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif // FP4_TYPE_SUPPORTED + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace dequantize_kernel { +#if FP4_TYPE_SUPPORTED +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // FP4_TYPE_SUPPORTED +} // namespace dequantize_kernel + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace dequantize_kernel; + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back());); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // FP4_TYPE_SUPPORTED +} +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh new file mode 100644 index 000000000..83ad8fd40 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -0,0 +1,688 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_kernel { + +using namespace ptx; +using namespace quantization_SF; +using namespace core; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +#define DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. + if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 +inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_kernel; + using namespace ptx; + checkCuDriverContext(stream); + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + bool use_colwise_scaling = output->has_columnwise_data(); + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + constexpr size_t CHUNK_DIM_Y = 128; + constexpr size_t CHUNK_DIM_X = 128; + constexpr size_t THREADS_PER_CHUNK = 128; + + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = output->scale_inv.shape[1]; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); + e8m0_t *const scales_colwise_e8m0_ptr = + use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + + const ScalingType scaling_type = + use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const nvfp4_second_stage_scale_ptr = + reinterpret_cast(output->scale.dptr); + + // Output data type is only required for the column-wise MXFP8 scaling. + // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work + const DType output_data_type = + use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, 4); + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_nvfp4_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; + const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; + + const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; + const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0; + + const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + + out_rowwise_scales_mem + out_colwise_scales_mem + + TMA_SHMEM_ALIGNMENT; + + const size_t dshmem_size = in_mem + out_mem; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_nvfp4_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + + kernel<<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_nvfp4_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + + kernel<<>>( + tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, + scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, + nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh new file mode 100644 index 000000000..7322bf265 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -0,0 +1,1287 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_transpose_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4 and transpose. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +namespace quantize_transpose_kernel { + +using namespace quantization_and_transposition_SF; +using namespace core; +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_NUM = 128; + +constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; + +constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; + +// Each call generates 4x uint32_t random numbers +constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4; + +constexpr size_t TILE_DIM_Y = 32; +constexpr size_t TILE_DIM_X = 128; + +// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D +constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; +constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 + +constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; +constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; +constexpr size_t STAGES = TILES_Y * TILES_X; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; +constexpr size_t BUFF_DIM_X = TILE_DIM_X; +constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; +constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; + +// Input buffer (BF16) +constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; +constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; + +// Output buffer (NVFP4) +constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; +constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; +constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; + +// Output transpose buffer (NVFP4) +constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; +constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; +constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; + +// Manual swizzling parameters to reduce SHMEM bank conflicts +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; + +constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; +constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 +constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 + +constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 +constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; +constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; + +static_assert(BUFF_DIM_Y >= SCALE_DIM && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); +static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); +static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 + +template +__global__ void __launch_bounds__(THREADS_NUM) + quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + int rnd_idx = 0; + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = + (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements + fp4e2m1x4 regs[SCALE_DIM / 4]; + +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void __launch_bounds__(THREADS_NUM) + quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + const __grid_constant__ CUtensorMap tensor_map_output_t, + nvfp4_scale_t *const scales_ptr, + nvfp4_scale_t *const scales_t_ptr, const float *noop, + const float *const amax_rowwise_ptr, + const float *const amax_colwise_ptr, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const size_t *rng_state) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + const size_t rng_sequence = + threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + // NEW: 2D Block-based scaling constants + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile + constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + const size_t warp_id = threadIdx.x / 32; + const size_t lane_id = threadIdx.x % 32; + float thread_amax = 0.0f; + const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + + // Helper function for warp reduction + auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { +#pragma unroll + for (int delta = 8; delta >= 1; delta /= 2) { + float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); + thread_amax = fmaxf(thread_amax, other_amax); + } + return thread_amax; + }; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + +#pragma unroll + for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + const size_t block_in_tile_y = block_iter; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + for (int elem = 0; elem < BLOCK_DIM; elem += 2) { + const size_t elem_0_row = block_iter * BLOCK_DIM + elem; + const size_t elem_1_row = elem_0_row + 1; + const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + const size_t elem_1_col = elem_0_col; + + const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; + const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; + + IType2 val_2x; + val_2x.x = in_sh[shmem_offset_0]; + val_2x.y = in_sh[shmem_offset_1]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); + } + + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else { + for (int elem = 0; elem < BLOCK_DIM; ++elem) { + const size_t elem_row = block_iter * BLOCK_DIM + elem; + const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + + // Bounds checking + const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows); + const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); + if (!row_out_of_bounds && !col_out_of_bounds) { + const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; + float elt = static_cast(in_sh[shmem_offset]); + + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset] = static_cast(elt); + } + + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + } + // Warp reduction to get block amax + block_amax = warp_reduce_amax(thread_amax, block_in_warp); + + if (lane_id == 0 || lane_id == 16) { + block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; + } + } + + // sync thread to ensure block_amax_matrix is done storing + __syncthreads(); + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 3. Scale elements + + // Load data in + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + } + } else { + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = tid_X_rowwise; + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + } + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_kernel + +template +void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_kernel; + using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. + // TODO(Frank): Is there a better way to do this? + bool return_transpose = output->has_columnwise_data(); + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + if (return_transpose) { + NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ diff --git a/transformer_engine/common/util/cast.cu b/transformer_engine/common/util/cast.cu deleted file mode 100644 index 107965d34..000000000 --- a/transformer_engine/common/util/cast.cu +++ /dev/null @@ -1,201 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/multi_stream.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "cast_kernels.cuh" -#include "dequantize_kernels.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transpose.h" - -void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - detail::quantize_helper(input, grad, output, dbias, - workspace, nullptr, stream); -} - -void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_noop); - using namespace transformer_engine; - - // Create config with noop tensor - QuantizationConfig quant_config; - quant_config.noop_tensor = noop; - - nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); -} - -void nvte_quantize_v2(const NVTETensor input, NVTETensor output, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_v2); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - detail::quantize_helper( - input, grad, output, dbias, workspace, quant_config, stream); -} - -void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, - NVTETensor workspace, cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr const NVTETensor activation_input = nullptr; - - detail::quantize_helper( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dsilu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_drelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dqgelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, - NVTETensor output, NVTETensor dbias, NVTETensor workspace, - cudaStream_t stream) { - NVTE_API_CALL(nvte_quantize_dbias_dsrelu); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = true; - constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; - - detail::quantize_helper>( - activation_input, input, output, dbias, workspace, nullptr, stream); -} - -void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { - NVTE_API_CALL(nvte_dequantize); - using namespace transformer_engine; - detail::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); -} - -void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, - const NVTEQuantizationConfig quant_configs, - const size_t num_tensors, cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_tensor_quantize); - using namespace transformer_engine; - - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; - constexpr bool IS_ACT = false; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - const size_t num_streams = nvte_get_num_compute_streams(); - - int num_stream_used = std::min(num_streams, num_tensors); - // wait for current stream to finish - NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); - } - - for (int i = 0; i < num_tensors; i++) { - detail::quantize_helper( - inputs[i], grad, outputs[i], dbias, workspace, nullptr, - detail::get_compute_stream(i % num_streams)); - } - - // record events on compute streams - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA( - cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); - } - // wait for all compute streams to finish - for (int s = 0; s < num_stream_used; s++) { - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); - } -} diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh deleted file mode 100644 index b0498602b..000000000 --- a/transformer_engine/common/util/cast_kernels.cuh +++ /dev/null @@ -1,2188 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file cast_kernels.cuh - * \brief CUDA kernels to cast to/from FP8/MXFP8. - */ - -#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ - -#include -#include -#include -#include - -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "nvfp4_transpose.cuh" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { - -namespace mxfp8_kernel { - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 32; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; - static_assert(BUFF_DIM_Y == 32); - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - - OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } - - float block_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - if (rowwise_scale_is_within_bounds) { - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - - // 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - - parity ^= 1; - - if constexpr (IS_DBIAS) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); - - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - - const int shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; - } - } - const int dbias_stride = cols; - const int dbias_offset_Y = blockIdx.y; - const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace mxfp8_kernel - -namespace nvfp4_kernel { - -using namespace ptx; - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 16; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t BUFF_DIM_Y = 32; - -constexpr size_t PACK_SIZE = 8; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 - -// Compute per-block E4M3 encoding/decoding scaling factor -__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { - constexpr float rcp_6f = 1.0f / 6.0f; - // const float S_dec_b = block_amax * rcp_6f; - // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - // return S_dec_b_fp8; - return static_cast(block_amax * rcp_6f * S_enc); -} - -#define DIRECT_SCALING_FACTORS_STORE 1 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, - const float *noop, float *const amax_ptr, - const float *const nvfp4_second_stage_scale_ptr, const size_t rows, - const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool ROWWISE_SCALING = true; - constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = - (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); - - using IType2 = typename ptx::FPx2; - - if constexpr (!COMPUTE_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; - constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; - - static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && - "Number of buffer rows must be greater or equal to the size of the columwise " - "scaling block\0"); - static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); - static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && - "Number of buffer rows must be greater or equal to the number of rowwise " - "processing threads in Y dimension\0"); - - constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size - constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - - constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; - // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a sufficient number of - // // threads to process one row in a single iteration - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * CHUNK_DIM_X; - const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const int tid_Y_colwise = 0; - const int tid_X_colwise = threadIdx.x; - - const int thread_offset_Y_rowwise = tid_Y_rowwise; - const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const int thread_offset_Y_colwise = tid_Y_colwise; - const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements - - const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const int col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; - const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_nvfp4 = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_mxfp8 = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t buff_size_nvfp4_scales = - CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); - - constexpr size_t in_mem = buff_size_aligned_in; - - constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); - constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); - constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); - constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); - fp8e4m3 *out_rowwise_scales_sh = - reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); - e8m0_t *out_colwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = - (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); - - float thread_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const int buff = stage % BUFFS_NUM; - const int next_stage = stage + 1; - const int stage_offset_Y = stage * BUFF_DIM_Y; - - const int buff_offset_in = buff * BUFF_IN_DIM; - const int buff_offset_out = buff * BUFF_OUT_DIM; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const int next_buff = next_stage % BUFFS_NUM; - const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const int global_offset_Y = block_offset_Y + next_stage_offset_Y; - const int global_offset_X = block_offset_X; - const int next_buff_offset = next_buff * BUFF_IN_DIM; - - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], 0); - - float block_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; - - block_amax = 0.0f; - float in_compute_colwise[SCALE_DIM_Y]; - IType in_colwise_IType[SCALE_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType block_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); - } - block_amax = static_cast(block_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); - - const int global_scales_offset_Y = scales_offset_Y_colwise + stage; - const int global_scales_offset_X = scales_offset_X_colwise; - const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - if (colwise_scale_is_within_bounds) { - scales_colwise_e8m0[scale_idx] = biased_exponent; - } - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; - out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; -#pragma unroll - for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { - const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; - - const int shmem_offset_base_rowwise_in = - buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; - const int shmem_offset_base_rowwise_out = - buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; - - const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; - - block_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find NVFP4-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = - (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E4M3 scaling factor - const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); - -#if DIRECT_SCALING_FACTORS_STORE - // Check boundaries - if (rowwise_scale_is_within_bounds) { - const int scales_offset_Y = - scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; - const int scales_offset_X = scales_offset_X_rowwise; - const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; - scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; - } -#else - const int shmem_scales_offset_Y = - stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; - const int shmem_scales_offset_X = tid_X_rowwise; - const int scale_idx = - shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; - out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; -#endif - // Compute "correct" per-block encoding scaling factor - const float block_scale_inverse = - __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 - -// 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; // Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - in01 = in_IType[w].data.elt[2 * e]; - in23 = in_IType[w].data.elt[2 * e + 1]; - } else if constexpr (IS_CACHED_ACT_OP) { - in01.x = in_cached[w].data.elt[4 * e]; - in01.y = in_cached[w].data.elt[4 * e + 1]; - in23.x = in_cached[w].data.elt[4 * e + 2]; - in23.y = in_cached[w].data.elt[4 * e + 3]; - } else { - const int j = w * PACK_SIZE + 4 * e; - in01.x = in_compute_rowwise[j]; - in01.y = in_compute_rowwise[j + 1]; - in23.x = in_compute_rowwise[j + 2]; - in23.y = in_compute_rowwise[j + 3]; - } - fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); - ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); - } - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); - } - } - } - - __builtin_assume(thread_amax >= 0); - __builtin_assume(block_amax >= 0); - thread_amax = fmaxf(thread_amax, block_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; - const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - -#if !DIRECT_SCALING_FACTORS_STORE - // Vectorized store of scaling factors. - // Each thread stores multiple scaling factors in one store instruction. - if constexpr (ROWWISE_SCALING) { - // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; - const int scale_idx_global = - scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; - const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; - - if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && - (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { - using ScalesVec_t = Vec; - const ScalesVec_t &scales = - *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); - scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); - } - } -#endif - - float chunk_amax = 0.0f; - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - chunk_amax = reduce_max(thread_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, chunk_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace nvfp4_kernel - -constexpr size_t FP8_CHUNK_DIM_Y = 128; -constexpr size_t FP8_CHUNK_DIM_X = 128; -constexpr size_t FP8_THREADS_PER_CHUNK = 128; -constexpr size_t FP8_BUFFERS_NUM = 2; -constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); - -constexpr size_t FP8_BUFFER_DIM_Y = 16; -constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 -constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 - -constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); - -template -__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output, - float *const dbias_workspace, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, - const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - const size_t dbias_offset_Y = blockIdx.y + tid_Y; - const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; - const bool col_out_of_bounds = my_column >= cols; - const size_t dbias_stride = cols; - - float partial_dbias = 0.f; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - - constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const size_t chunk_stage_offset_X = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); - } - } - -#pragma unroll - for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const size_t buff = iter % FP8_BUFFERS_NUM; - const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; - if (next_iter < FP8_ITERATIONS) { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); - } - } - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = row >= rows; - const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if constexpr (IS_DACT) { - if (!out_of_bounds) { - partial_dbias += elt; - } - } else { - // If no activation, elt is 0 so we can safely do this - partial_dbias += elt; - } - } - __builtin_assume(amax >= 0); - if (IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); - } - out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - - if constexpr (IS_DBIAS) { - const size_t dbias_offset_X = my_column; - const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t CHUNKS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; -constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; -constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; -constexpr size_t CHUNKS_PER_ITERATION = 32; -constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; -constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; -constexpr size_t SHMEM_BUFFERS = 2; -static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); - -template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; - const IType *input = input_ptr + block_offset; - OType *output = output_ptr + block_offset; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - - constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); - -#pragma unroll - for (int iter = 0; iter < ITERATIONS; ++iter) { - const size_t buff = iter % SHMEM_BUFFERS; - const size_t it_offset = iter * SHMEM_DIM; - - const size_t next_iter = iter + 1; - const size_t next_buff = next_iter % SHMEM_BUFFERS; - const size_t next_iter_offset = next_iter * SHMEM_DIM; - - if (next_iter < ITERATIONS) { - copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, - &(mbar[next_iter]), is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; - float elt = static_cast(in_sh[buff][shmem_offset]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(elt)); - out_sh[buff][shmem_offset] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - ptx::cp_async_bulk_tensor_1d_shared_to_global( - reinterpret_cast(output + it_offset), - reinterpret_cast(&out_sh[buff]), transaction_size_OUT); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read<1>(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; -template -__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, - const size_t rows, const size_t cols) { - using ComputeVec = Vec; - using OutputVec = Vec; - - const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - if (thread_id * nvec >= cols) { - return; - } - - const float *const thread_in_base = dbias_partial + thread_id * nvec; - OType *const thread_out_base = dbias_output + thread_id * nvec; - - ComputeVec ldg_vec; - ComputeVec acc_vec; - acc_vec.clear(); - for (int i = 0; i < rows; ++i) { - ldg_vec.load_from(thread_in_base + i * cols); -#pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += ldg_vec.data.elt[e]; - } - } - - OutputVec stg_vec; -#pragma unroll - for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); - } - stg_vec.store_to(thread_out_base); -} - -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream) { - constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 - constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); - - NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); - const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); - - reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { - const size_t N = product(input.data.shape); - - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - NVTE_CHECK(isFullTile, "Only full tiles are supported."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - const size_t chunks = DIVUP(N, CHUNK_SIZE); - const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - const float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(THREADS_PER_BLOCK); - const dim3 grid(blocks); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - const IType *input_ptr = reinterpret_cast(input.data.dptr); - OType *output_ptr = reinterpret_cast(output->data.dptr); - - cast_fp8_1D_kernel<<>>( - input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) - ); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, - Tensor *workspace, cudaStream_t stream) { - checkCuDriverContext(stream); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); - const size_t blocks_Y = chunks_Y; - const size_t blocks_X = chunks_X; - - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(FP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - } - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); - - cast_fp8_2D_kernel - <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, - workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, - const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - using namespace mxfp8_kernel; - checkCuDriverContext(stream); - - bool use_rowwise_scaling = output->has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - - if (use_rowwise_scaling) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - - constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; - - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -// This kernel supports only two scaling cases: -// 1. r16c0 - Rowwise NVFP4 -// 2. r16c32 - Rowwise NVFP4 AND Colwise MXFP8 -template -void nvfp4_quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { - using namespace nvfp4_kernel; - using namespace ptx; - checkCuDriverContext(stream); - - NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated."); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - bool use_colwise_scaling = output->has_columnwise_data(); - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - constexpr size_t CHUNK_DIM_Y = 128; - constexpr size_t CHUNK_DIM_X = 128; - constexpr size_t THREADS_PER_CHUNK = 128; - - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = output->scale_inv.shape[1]; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr); - e8m0_t *const scales_colwise_e8m0_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - - const ScalingType scaling_type = - use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE; - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const float *const nvfp4_second_stage_scale_ptr = - reinterpret_cast(output->scale.dptr); - - // Output data type is only required for the column-wise MXFP8 scaling. - // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work - const DType output_data_type = - use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output_data_type, OType, alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, nvfp4_kernel::BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, sizeof(IType) * 8); - - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - nvfp4_kernel::BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8); - } - - constexpr size_t buff_elems = nvfp4_kernel::BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = nvfp4_kernel::BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_nvfp4 = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out_mxfp8 = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_nvfp4_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3); - constexpr size_t buff_size_mxfp8_scales = - (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t); - - constexpr size_t in_mem = buff_size_aligned_in; - - const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4; - const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0; - - const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales; - const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0; - - const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem + - out_rowwise_scales_mem + out_colwise_scales_mem + - TMA_SHMEM_ALIGNMENT; - - const size_t dshmem_size = in_mem + out_mem; - - switch (scaling_type) { - case ScalingType::ROWWISE: - cudaFuncSetAttribute( - cast_nvfp4_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - - cast_nvfp4_kernel - <<>>( - tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, - scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, - nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - break; - case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( - cast_nvfp4_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - - cast_nvfp4_kernel - <<>>( - tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise, - scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr, - nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - break; - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace detail { - -using Empty = transformer_engine::Empty; - -__device__ inline float identity(float value, const Empty &) { return value; } - -struct DequantizeParam { - const float *scale_inv; -}; - -__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); -} - -} // namespace detail - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(noop->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream) { - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; - const size_t N = product(input->data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input->data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace { - -static bool is_full_tile_1D_tensor(const Tensor *const t) { - const size_t N = product(t->data.shape); - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - return isFullTile; -} - -bool dimensions_supported_by_TMA(const Tensor *const t) { - const size_t cols = t->flat_last_dim(); - constexpr size_t TMA_bytes = 16; - const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); - return cols % alignment_requirement == 0; -} - -} // namespace - -// Supported by the Arch >= 10.0 -template -void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 - cast_fp8_1D(input, output, stream); - } else { - // Unaligned - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 (+dAct) - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } else { - // Unaligned - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - } else { - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -// Supported by the Arch < 10.0 -template -void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { - // zhongboz: should we just ignore IS_ACT here? - NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); - } - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -template -void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, - Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { - // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } -} - -namespace detail { - -template -void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, - NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - const Tensor *input_tensor; - const Tensor *activation_input_tensor; - if constexpr (IS_DBIAS || IS_DACT) { - // backward - input is incoming gradient - input_tensor = convertNVTETensorCheck(grad); - activation_input_tensor = convertNVTETensor(input); - } else { - // forward = input is activation input - input_tensor = convertNVTETensorCheck(input); - activation_input_tensor = nullptr; - } - auto output_tensor = convertNVTETensorCheck(output); - auto dbias_tensor = convertNVTETensor(dbias); - auto workspace_tensor = convertNVTETensor(workspace); - - // Quantization config - QuantizationConfig quant_config_cpp; - if (quant_config != nullptr) { - quant_config_cpp = *reinterpret_cast(quant_config); - } - - // Noop flag - Tensor dummy_tensor; - Tensor *noop_tensor = &dummy_tensor; - if (quant_config_cpp.noop_tensor != nullptr) { - noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor); - } - - // Check for unsupported options - if (quant_config_cpp.stochastic_rounding) { - NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING, - "Stochastic rounding is only supported for NVFP4 quantization."); - } - - // Dispatch to quantization kernel depending on data format - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } - case NVTE_NVFP4_1D_SCALING: { - // Check tensors - CheckNoopTensor(*noop_tensor, "cast_noop"); - CheckInputTensor(*input_tensor, "input"); - CheckOutputTensor(*output_tensor, "output", false); - - // Choose kernel - int32_t rows = input_tensor->flat_first_dim(); - int32_t cols = input_tensor->flat_last_dim(); - auto dtype = input_tensor->dtype(); - bool use_optimized_kernel = dtype == DType::kBFloat16 && rows % 32 == 0 && cols % 32 == 0 && - output_tensor->has_data(); - - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { - if (quant_config_cpp.nvfp4_2d_quantization) { - nvfp4_quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } else { - nvfp4_quantize_transpose( - *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); - } - } else { - auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax - : output_tensor->columnwise_amax; - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_NVFP4_1D_SCALING for " - "2D quantization"); - quantize_transpose_vector_blockwise_fp4( - /*input=*/input_tensor->data, /*global_amax=*/global_amax, - /*scale_inv=*/output_tensor->scale_inv, - /*scale_inv_t=*/output_tensor->columnwise_scale_inv, - /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data, - /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(), - /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false, - /*swizzled_scale=*/false, - /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, - /*rng_state=*/quant_config_cpp.rng_state, - /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); - } - break; - } - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor->data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales; - float epsilon = quant_config_cpp.amax_epsilon; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT); - columnwise_option = columnwise_compact - ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor->data, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } -} - -} // namespace detail -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2f20817fb..005a60067 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -36,6 +36,8 @@ __device__ inline OType sigmoid(const IType val, const Empty&) { return 1.f / (1.f + expf(-cval)); } +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + template __device__ inline OType dsigmoid(const IType val, const Empty& e) { const float cval = val; diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index aeac2b4a2..6605d9cad 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -449,13 +449,12 @@ static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); -#if CUDA_VERSION >= 12080 +#if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; using fp4e2m1x2 = __nv_fp4x2_e2m1; using fp4e2m1x4 = __nv_fp4x4_e2m1; static_assert(sizeof(fp4e2m1x2) == 1); static_assert(sizeof(fp4e2m1x4) == 2); -#endif // CUDA_VERSION >= 12080 // When converting to .e2m1x2 data formats, the destination operand d has .b8 type. // When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, @@ -464,7 +463,6 @@ static_assert(sizeof(fp4e2m1x4) == 2); // from input b is stored in the lower 4 bits of d. // SIMD like "Fused" cast + multiplication (x4) -#if CUDA_VERSION >= 12080 template __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, const float scale) { @@ -474,7 +472,192 @@ __device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, cons const float x3 = static_cast(in23.y) * scale; out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); } -#endif // CUDA_VERSION >= 12080 + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} +#endif // FP4_TYPE_SUPPORTED // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, From 26370b117169aec87df9e86f90814a4faabbcc09 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 30 Oct 2025 15:50:16 -0700 Subject: [PATCH 095/141] [PyT] Bump the min version expected to supported FP8 current scaling determinism on Blackwell (#2316) * Bump the min version expected to supported FP8 cs det on Blackwell Signed-off-by: Kshitij Lakhani * Disable fused attn for cudnn < 9.14 for FP8 CS. Disable fused attn for cudnn < 9.18 for FP8 deterministic CS Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../attention/dot_product_attention/utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 25dc0e96c..7d4a4f86d 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -477,9 +477,21 @@ def get_attention_backend( if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False - elif cudnn_version < (9, 14, 0): - logger.debug("Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0") - use_fused_attention = False + # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling + # determinism for Blackwell + else: + if cudnn_version < (9, 14, 0): + logger.debug( + "Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0" + ) + use_fused_attention = False + else: + if deterministic and cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for FP8 current scaling requiring determinism" + " with cuDNN < 9.18.0" + ) + use_fused_attention = False if device_compute_capability == (12, 0): if use_flash_attention: From 1269b2e209c392d41d81f12391cdabc0d5a132fd Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 30 Oct 2025 16:45:44 -0700 Subject: [PATCH 096/141] [JAX] Ensure JAX reference impl uses an accurate backend in our tests (#2322) Ensure JAX reference impl uses an accurate backend Signed-off-by: Jeremy Berchtold --- qa/L1_jax_distributed_unittest/test.sh | 3 ++- qa/L2_jax_distributed_unittest/test.sh | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 270f0df15..42b70a28e 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -8,5 +8,6 @@ set -xe : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate. +XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh diff --git a/qa/L2_jax_distributed_unittest/test.sh b/qa/L2_jax_distributed_unittest/test.sh index 0b7372650..de5624a59 100644 --- a/qa/L2_jax_distributed_unittest/test.sh +++ b/qa/L2_jax_distributed_unittest/test.sh @@ -8,4 +8,5 @@ set -xe : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate. +XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* From 006670de2f518022ff2a625857f626137a764266 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 31 Oct 2025 08:36:03 -0700 Subject: [PATCH 097/141] [JAX] Fix mesh resource requirement when no mesh (#2307) * Fix mesh resource requirement when no mesh Signed-off-by: Jeremy Berchtold * do not require meshresource if all axes are manual axes Signed-off-by: Jeremy Berchtold * remove abstract_mesh is None check Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/sharding.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index adb67e358..7f204e768 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -75,6 +75,16 @@ def get_sharding_map_logic_axis_to_mesh_axis(): """ Generate a dict to map logical axes to mesh axes. """ + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + if mesh is None or mesh.empty: + # If no mesh is defined, return an empty dict and do not require a MeshResource context to be present + return {} + + abstract_mesh = get_abstract_mesh() + if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): + # If all mesh axes are manual axes, return an empty dict and do not require a MeshResource context to be present + return {} + gsr = global_mesh_resource() is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 From e7227af98070ebfcdb08b7f0a99bb87abe7b8532 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Fri, 31 Oct 2025 17:15:21 +0100 Subject: [PATCH 098/141] [Common] Deleted unused header (#2324) Deleted unused header Signed-off-by: Oleg Goncharov --- .../common/util/nvfp4_transpose.cuh | 1514 ----------------- 1 file changed, 1514 deletions(-) delete mode 100644 transformer_engine/common/util/nvfp4_transpose.cuh diff --git a/transformer_engine/common/util/nvfp4_transpose.cuh b/transformer_engine/common/util/nvfp4_transpose.cuh deleted file mode 100644 index 629520aeb..000000000 --- a/transformer_engine/common/util/nvfp4_transpose.cuh +++ /dev/null @@ -1,1514 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file nvfp4_transpose.cuh - * \brief CUDA kernels to cast to NVFP4 and transpose. - */ - -#ifndef TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ -#define TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ - -#include -#include -#include - -#if FP4_TYPE_SUPPORTED -#include -#endif // FP4_TYPE_SUPPORTED -#include - -#include "../common.h" -#include "../utils.cuh" -#include "curanddx.hpp" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" - -namespace transformer_engine { - -#if FP4_TYPE_SUPPORTED -namespace nvfp4_transpose { - -using namespace ptx; -using nvfp4_scale_t = fp8e4m3; - -constexpr size_t SCALE_DIM = 16; // NVFP4 block (x16 elts) - -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_NUM = 128; - -constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM; -constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM; - -constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM; -constexpr size_t RNG_GENS_PER_THREAD = - SCALES_PER_THREAD / 4; // Each call generates 4x uint32_t random numbers - -constexpr size_t TILE_DIM_Y = 32; -constexpr size_t TILE_DIM_X = 128; - -// SHould this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D -constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM; -constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM; // 128 / 16 = 8 - -constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y; -constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X; -constexpr size_t STAGES = TILES_Y * TILES_X; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t BUFF_DIM_Y = TILE_DIM_Y; -constexpr size_t BUFF_DIM_X = TILE_DIM_X; -constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X; -constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM; - -// Input buffer (BF16) -constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y; -constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X; -constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X; - -// Output buffer (NVFP4) -constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y; -constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8; -constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X; - -// Output transpose buffer (NVFP4) -constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X; -constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8; -constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X; - -// Manual swizzling parameters to reduce SHMEM bank conflicts -constexpr size_t PACK_SIZE = 8; -constexpr size_t WAVES = SCALE_DIM / PACK_SIZE; - -constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM; -constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X; // 128 / 16 = 8 -constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE; // 128 / 8 = 16 - -constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE; // 32/ 16 = 2 -constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM; -constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE; - -static_assert(BUFF_DIM_Y >= SCALE_DIM && - "Number of buffer rows must be greater or equal to the size of the columwise " - "scaling block\0"); -static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); -static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && - "Number of buffer rows must be greater or equal to the number of rowwise " - "processing threads in Y dimension\0"); - -// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 - -// Compute per-block E4M3 encoding/decoding scaling factor -__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, - const float S_enc) { - // constexpr float rcp_6f = 1.0f / 6.0f; - // const float S_dec_b = block_amax * rcp_6f; - // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - // return S_dec_b_fp8; - // NOTE: Divide by 6.0f is not elegant and not efficient. - // However, this is part of the emulation code to ensure exact match. - using namespace detail; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; - const float S_dec_b = block_amax / fp4_max * S_enc; - return static_cast(fminf(S_dec_b, TypeExtrema::max)); -} - -// Compute the global encode scale factor for a given global amax -__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { - using namespace detail; - constexpr float fp8_max = TypeExtrema::max; // 448.0f; - constexpr float fp4_max = TypeExtrema::max; // 6.0f; - float global_encode_scale = fp8_max * fp4_max / global_amax; - // If scale is infinity, return max value of float32 - global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); - // If global amax is 0 or infinity, return 1 - if (global_amax == 0.0f || global_encode_scale == 0.0f) { - return 1.0f; - } - return global_encode_scale; -} - -__device__ __forceinline__ uint32_t -get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> - &rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 - uint4 &random_uint4, int &rnd_idx) { - if (rnd_idx == 4) { - rnd_idx = 0; - random_uint4 = rng.generate4(); - } - - // Treat uint4 as an array of 4x uint32_t elements for indexing - const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); - const uint32_t rbits = rbits_arr[rnd_idx++]; - return rbits; -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( - const uint64_t in_4x, const float2 scale, const uint32_t rbits) { - uint16_t out_4x = 0; - constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - if constexpr (has_rs) { - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); - } else { - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - } - return *reinterpret_cast(&out_4x); -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, - const float2 scale, - const uint32_t rbits) { - constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; - uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. - if constexpr (is_blackwell) { - // NOTE: rbits unused for rn. - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b16 v0_bf16; \n\t" - ".reg.b16 v1_bf16; \n\t" - ".reg.b16 v2_bf16; \n\t" - ".reg.b16 v3_bf16; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" - "cvt.f32.bf16 v0, v0_bf16; \n\t" - "cvt.f32.bf16 v1, v1_bf16; \n\t" - "cvt.f32.bf16 v2, v2_bf16; \n\t" - "cvt.f32.bf16 v3, v3_bf16; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(in_4x), "l"(reinterpret_cast(scale))); - } else { - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - } - return reinterpret_cast(&out_4x)[0]; -} - -template -__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, - const float2 scale, - const uint32_t rbits) { - if constexpr (USE_STOCHASTIC_ROUNDING) { - return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); - } else { - return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); - } -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( - const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { - uint16_t out_4x = 0; - constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; - if constexpr (has_rs) { - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order - "}" - : "=h"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale)), "r"(rbits)); - } else { - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - } - return *reinterpret_cast(&out_4x); -} - -__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, - const float2 in23, - const float2 scale, - const uint32_t rbits) { - constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; - uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. - if constexpr (is_blackwell) { - // NOTE: rbits unused for rn. - asm volatile( - "{\n" - ".reg.b64 v01; \n\t" - ".reg.b64 v23; \n\t" - ".reg.b32 v0; \n\t" - ".reg.b32 v1; \n\t" - ".reg.b32 v2; \n\t" - ".reg.b32 v3; \n\t" - ".reg.b8 f0; \n\t" - ".reg.b8 f1; \n\t" - "mov.b64 {v0, v1} , %1; \n\t" - "mov.b64 {v2, v3} , %2; \n\t" - "mov.b64 v01, {v0, v1}; \n\t" - "mov.b64 v23, {v2, v3}; \n\t" - "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order - "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order - "mov.b64 {v1, v0}, v01; \n\t" - "mov.b64 {v3, v2}, v23; \n\t" - "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" - "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" - "mov.b32 %0, {f0, f1, f0, f1};\n\t" - "}" - : "=r"(out_4x) - : "l"(reinterpret_cast(in01)), - "l"(reinterpret_cast(in23)), - "l"(reinterpret_cast(scale))); - } else { - NVTE_DEVICE_ERROR( - "FP4 cvt PTX instructions are architecture-specific. " - "Try recompiling with sm_XXXa instead of sm_XXX."); - } - return reinterpret_cast(&out_4x)[0]; -} - -template -__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, - const float2 scale, - const uint32_t rbits) { - if constexpr (USE_STOCHASTIC_ROUNDING) { - return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); - } else { - return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); - } -} - -template -__global__ void __launch_bounds__(THREADS_NUM) - nvfp4_transpose_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, - const float *noop, const float *const amax_rowwise_ptr, - const float *const amax_colwise_ptr, const size_t rows, - const size_t cols, const size_t scale_stride, - const size_t scale_stride_t, const size_t *rng_state) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = - (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); - - using IType2 = typename ptx::FPx2; - - if constexpr (!COMPUTE_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - - const size_t rng_sequence = - threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; - const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; - const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - rng.init(rng_seed, rng_sequence, rng_offset); - uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; - - int rnd_idx = - 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; - const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; - - const size_t chunk_rows = rows - block_offset_Y; - - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; - const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const size_t tid_X_colwise = threadIdx.x; - const size_t tid_Y_t = tid_X_colwise; - // const size_t tid_X_t = 0; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; - const size_t scales_offset_X_t = scales_block_offset_X_t; - - const size_t SFs_per_row = cols / SCALE_DIM; - - const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; - const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; - - // Helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - - constexpr size_t in_mem = buff_size_aligned_in; - - constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; - constexpr size_t out_mem_colwise_data = buff_size_aligned_out; - constexpr size_t out_mem_rowwise_scales = 0; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); - fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); - - nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - // Compute a global encoding/decoding scaling factors for all S_dec_b - const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) - ? 1.0f - : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); - // NOTE: This is to match with how emulation code was written. - const float S_dec_rowwise = 1.0 / S_enc_rowwise; - - const float S_enc_colwise = (amax_colwise_ptr == nullptr) - ? S_enc_rowwise - : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); - const float S_dec_colwise = 1.0 / S_enc_colwise; - - float thread_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - -#pragma unroll - for (size_t stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - const size_t buff_offset_in = buff * BUFF_IN_SIZE; - const size_t buff_offset_out = buff * BUFF_OUT_SIZE; - const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; - - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], 0); - - float block_amax = 0.0f; - - // COLWISE scaling - if constexpr (RETURN_TRANSPOSE) { -#pragma unroll - for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { - const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; - const size_t in_thread_offset_X = thread_offset_X_colwise; - - const size_t out_t_thread_offset_Y = thread_offset_X_colwise; - const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; - - const size_t shmem_offset_base_colwise_in = - buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; - const size_t shmem_offset_base_colwise_out_t = - buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; - - block_amax = 0.0f; - float in_compute_colwise[SCALE_DIM]; - IType in_colwise_IType[SCALE_DIM]; - // 1. Read/Compute elements. Find NVFP4-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType block_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < SCALE_DIM; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); - } - block_amax = static_cast(block_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < SCALE_DIM; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = - (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); - - // Store scaling factors through SHMEM - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - - // 3. Scale elements - fp4e2m1x4 regs[SCALE_DIM / 4]; - -#pragma unroll - for (int e = 0; e < SCALE_DIM / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); - regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, - rbits); - } else { - const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); - const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); - regs[e] = mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); - } - } - - const int group = thread_lane / 16; - uint32_t val[2]; - uint32_t *regs_4x = reinterpret_cast(regs); - - // Helps reducing bank conflicts - switch (group) { - case 0: - val[0] = regs_4x[0]; - val[1] = regs_4x[1]; - break; - case 1: - val[0] = regs_4x[1]; - val[1] = regs_4x[0]; - - break; - } - uint32_t *out_t_data_sh_as_uint32_t = - reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); - out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; - out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; - } - } - - // ROWWISE scaling - { - const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; -#pragma unroll - for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { - const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; - - const size_t shmem_offset_base_rowwise_in = - buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; - const size_t shmem_offset_base_rowwise_out = - buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; - - const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; - - block_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find NVFP4-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - block_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const size_t j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = - (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - block_amax = fmaxf(block_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - block_amax = fmaxf(block_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - -// 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); - out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else if constexpr (IS_CACHED_ACT_OP) { - const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); - out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else { - const int j = w * PACK_SIZE + 4 * e; - const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); - const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); - out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); - } - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_data_sh[shmem_offset_rowwise]); - } - } - } - - __builtin_assume(thread_amax >= 0); - thread_amax = fmaxf(thread_amax, block_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - - const size_t global_offset_Y_t = block_offset_Y_t; - const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; - - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, - reinterpret_cast(&out_data_sh[buff_offset_out])); - - if constexpr (RETURN_TRANSPOSE) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_t), global_offset_X_t, - global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } // end of stages - - // Vectorized store scaling factors through SHMEM - if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { - using ScalesVec = Vec; - const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; - ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); - const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; - const size_t count = // number of scales in Y dimension of this chunk - (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); - nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; - constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); - if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { - // Fast path: vectorized store when destination is properly aligned - scales_vec.store_to(dst); - } else { - // Safe path: element-wise store for tails or unaligned destinations - scales_vec.store_to_elts(dst, 0, count); - } - } - - destroy_barriers(mbar, is_master_thread); -#else - NVTE_DEVICE_ERROR("sm_100 or higher is required."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -template -__global__ void __launch_bounds__(THREADS_NUM) - nvfp4_transpose_kernel_2D(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - const __grid_constant__ CUtensorMap tensor_map_output_t, - nvfp4_scale_t *const scales_ptr, nvfp4_scale_t *const scales_t_ptr, - const float *noop, const float *const amax_rowwise_ptr, - const float *const amax_colwise_ptr, const size_t rows, - const size_t cols, const size_t scale_stride, - const size_t scale_stride_t, const size_t *rng_state) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = - (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); - - using IType2 = typename ptx::FPx2; - - if constexpr (!COMPUTE_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - const size_t rng_sequence = - threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM; - const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; - const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; - - transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; - rng.init(rng_seed, rng_sequence, rng_offset); - uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0}; - - int rnd_idx = - 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x - - // NEW: 2D Block-based scaling constants - constexpr size_t BLOCK_DIM = 16; - constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 - constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 - constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile - constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; - const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; - - const size_t chunk_rows = rows - block_offset_Y; - - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; - const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const size_t tid_X_colwise = threadIdx.x; - const size_t tid_Y_t = tid_X_colwise; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; - const size_t scales_offset_X_t = scales_block_offset_X_t; - - const size_t SFs_per_row = cols / SCALE_DIM; - - const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; - const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; - - // Helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - - constexpr size_t in_mem = buff_size_aligned_in; - - constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; - constexpr size_t out_mem_colwise_data = buff_size_aligned_out; - constexpr size_t out_mem_rowwise_scales = 0; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); - fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); - - nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); - nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( - dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - // Compute a global encoding/decoding scaling factors for all S_dec_b - const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) - ? 1.0f - : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); - // NOTE: This is to match with how emulation code was written. - const float S_dec_rowwise = 1.0 / S_enc_rowwise; - - const float S_enc_colwise = (amax_colwise_ptr == nullptr) - ? S_enc_rowwise - : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); - const float S_dec_colwise = 1.0 / S_enc_colwise; - - const size_t warp_id = threadIdx.x / 32; - const size_t lane_id = threadIdx.x % 32; - float thread_amax = 0.0f; - const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; - - // Helper function for warp reduction - auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { -#pragma unroll - for (int delta = 8; delta >= 1; delta /= 2) { - float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); - thread_amax = fmaxf(thread_amax, other_amax); - } - return thread_amax; - }; - - initialize_barriers(mbar, is_master_thread); - - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - -#pragma unroll - for (size_t stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - const size_t buff_offset_in = buff * BUFF_IN_SIZE; - const size_t buff_offset_out = buff * BUFF_OUT_SIZE; - const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; - - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], 0); - - float block_amax = 0.0f; - -#pragma unroll - for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; - const size_t block_in_tile_y = block_iter; - const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; - - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - for (int elem = 0; elem < BLOCK_DIM; elem += 2) { - const size_t elem_0_row = block_iter * BLOCK_DIM + elem; - const size_t elem_1_row = elem_0_row + 1; - const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; - const size_t elem_1_col = elem_0_col; - - const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; - const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; - - IType2 val_2x; - val_2x.x = in_sh[shmem_offset_0]; - val_2x.y = in_sh[shmem_offset_1]; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); - } - - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else { - for (int elem = 0; elem < BLOCK_DIM; ++elem) { - const size_t elem_row = block_iter * BLOCK_DIM + elem; - const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; - - // Bounds checking - const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + elem_row >= rows); - const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); - if (!row_out_of_bounds && !col_out_of_bounds) { - const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; - float elt = static_cast(in_sh[shmem_offset]); - - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset] = static_cast(elt); - } - - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } - } - // Warp reduction to get block amax - block_amax = warp_reduce_amax(thread_amax, block_in_warp); - - if (lane_id == 0 || lane_id == 16) { - block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; - } - } - - // sync thread to ensure block_amax_matrix is done storing - __syncthreads(); - - // COLWISE scaling - if constexpr (RETURN_TRANSPOSE) { -#pragma unroll - for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { - const size_t block_in_tile_y = it; - const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; - - const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; - const size_t in_thread_offset_X = thread_offset_X_colwise; - - const size_t out_t_thread_offset_Y = thread_offset_X_colwise; - const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; - - const size_t shmem_offset_base_colwise_in = - buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; - const size_t shmem_offset_base_colwise_out_t = - buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; - - block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; - float in_compute_colwise[SCALE_DIM]; - IType in_colwise_IType[SCALE_DIM]; - // 3. Scale elements - - // Load data in - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { -#pragma unroll - for (int i = 0; i < SCALE_DIM; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - } - } else { - for (int i = 0; i < SCALE_DIM; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - in_compute_colwise[i] = elt; - } - } - - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_colwise); - - // // Store scaling factors through SHMEM - const size_t scale_idx_sh = - tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; - out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - - fp4e2m1x4 regs[SCALE_DIM / 4]; -#pragma unroll - for (int e = 0; e < SCALE_DIM / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); - regs[e] = mul_cvt_bf16_to_fp4_4x(elts, block_scale_inverse_2x, - rbits); - } else { - const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]); - const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); - regs[e] = mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); - } - } - - const int group = thread_lane / 16; - uint32_t val[2]; - uint32_t *regs_4x = reinterpret_cast(regs); - - // Helps reducing bank conflicts - switch (group) { - case 0: - val[0] = regs_4x[0]; - val[1] = regs_4x[1]; - break; - case 1: - val[0] = regs_4x[1]; - val[1] = regs_4x[0]; - break; - } - uint32_t *out_t_data_sh_as_uint32_t = - reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); - out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; - out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; - } - } - - // ROWWISE scaling - { - const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; -#pragma unroll - for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { - const size_t block_in_tile_y = it; - const size_t block_in_tile_x = tid_X_rowwise; - const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; - - const size_t shmem_offset_base_rowwise_in = - buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; - const size_t shmem_offset_base_rowwise_out = - buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; - - block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; - float in_compute_rowwise[SCALE_DIM]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find NVFP4-block AMAX - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); - } - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const size_t j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (COMPUTE_ACTIVATIONS) { - elt = OP(elt, {}); - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; - } - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 - const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; - - // 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 4; ++e) { - const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); - IType2 in01; - IType2 in23; - if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { - const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); - out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else if constexpr (IS_CACHED_ACT_OP) { - const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); - out.data.elt[e] = mul_cvt_bf16_to_fp4_4x( - elts, block_scale_inverse_2x, rbits); - } else { - const int j = w * PACK_SIZE + 4 * e; - const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); - const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); - out.data.elt[e] = mul_cvt_fp32_to_fp4_4x( - in01, in23, block_scale_inverse_2x, rbits); - } - } - - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; - out.store_to(&out_data_sh[shmem_offset_rowwise]); - } - } - } - - __builtin_assume(thread_amax >= 0); - thread_amax = fmaxf(thread_amax, block_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - - const size_t global_offset_Y_t = block_offset_Y_t; - const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; - - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, - reinterpret_cast(&out_data_sh[buff_offset_out])); - - if constexpr (RETURN_TRANSPOSE) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_t), global_offset_X_t, - global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } // end of stages - - // Vectorized store scaling factors through SHMEM - if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { - using ScalesVec = Vec; - const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; - ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); - const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; - const size_t count = // number of scales in Y dimension of this chunk - (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); - nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; - constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); - if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { - // Fast path: vectorized store when destination is properly aligned - scales_vec.store_to(dst); - } else { - // Safe path: element-wise store for tails or unaligned destinations - scales_vec.store_to_elts(dst, 0, count); - } - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace nvfp4_transpose -#endif // FP4_TYPE_SUPPORTED - -template -void nvfp4_quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, - const QuantizationConfig *quant_config, cudaStream_t stream) { -#if FP4_TYPE_SUPPORTED - bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; - - // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to - // return the transposed data. - // TODO(Frank): Is there a better way to do this? - bool return_transpose = output->has_columnwise_data(); - - using namespace nvfp4_transpose; - using namespace ptx; - - checkCuDriverContext(stream); - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "input"); - CheckOutputTensor(*output, "output", false); - - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); - NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - if (return_transpose) { - NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); - NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), - "Transposed output must have FP4 type."); - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Transposed scaling tensor must be allocated"); - } - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - NVTE_CHECK(rows % 32 == 0, - "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA - NVTE_CHECK(cols % 32 == 0, - "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_NUM; - - const size_t scale_stride = output->scale_inv.shape[1]; - const size_t scale_stride_transpose = - return_transpose ? output->columnwise_scale_inv.shape[1] : 0; - - nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); - nvfp4_scale_t *const scales_transpose_ptr = - reinterpret_cast(output->columnwise_scale_inv.dptr); - - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); - const float *const amax_colwise_ptr = - reinterpret_cast(output->columnwise_amax.dptr); - - const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; - const size_t *rng_state = nullptr; - if (rng_state_tensor != nullptr) { - Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); - NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, - "RNG state should contain 2 64-bit values."); - NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, - "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); - rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); - } - - using IType = bf16; - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - alignas(64) CUtensorMap tensor_map_output_transpose{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, - sizeof(IType) * 8); - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, - 4); - if (return_transpose) { - create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, - BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); - } - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_scales = (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); - - constexpr size_t in_mem = buff_size_aligned_in; - - constexpr size_t out_data_mem = buff_size_aligned_out; - constexpr size_t out_data_transpose_mem = buff_size_aligned_out; - constexpr size_t out_scales_transpose_mem = buff_size_scales; - - constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; - - constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; - - TRANSFORMER_ENGINE_SWITCH_CONDITION( - use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = nvfp4_transpose_kernel; - - if constexpr (use_2d_quantization) { - kernel = nvfp4_transpose_kernel_2D; - } - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - });); -#else - NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif // FP4_TYPE_SUPPORTED -} -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_NVFP4_TRANSPOSE_CUH_ From c57ffc51a83ce16f7961df5b7b65c08080eb6639 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 3 Nov 2025 11:12:49 -0500 Subject: [PATCH 099/141] [JAX] L1_jax_distributed_test suit with individual executions (#2321) * L1 rework Signed-off-by: Phuong Nguyen * comment out test_multi_process_grouped_gemm for now Signed-off-by: Phuong Nguyen * rm e5m2 from test norm + MXFP8 Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- qa/L1_jax_distributed_unittest/test.sh | 36 +++++++++++++++++++++++--- tests/jax/multi_process_launch.sh | 10 ++++++- tests/jax/test_custom_call_compute.py | 7 ++++- 3 files changed, 48 insertions(+), 5 deletions(-) diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 42b70a28e..f4ea2dd68 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -2,12 +2,42 @@ # # See LICENSE for license information. -set -xe +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" +export NVTE_JAX_UNITTEST_LEVEL="L1" + # Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate. -XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* -SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh +export XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_dense.xml $TE_PATH/tests/jax/test_distributed_dense.py || test_fail "test_distributed_dense.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_helper.xml $TE_PATH/tests/jax/test_distributed_helper.py || test_fail "test_distributed_helper.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_layernorm.xml $TE_PATH/tests/jax/test_distributed_layernorm.py || test_fail "test_distributed_layernorm.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_mlp.xml $TE_PATH/tests/jax/test_distributed_layernorm_mlp.py || test_fail "test_distributed_layernorm_mlp.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_softmax.xml $TE_PATH/tests/jax/test_distributed_softmax.py || test_fail "test_distributed_softmax.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" + +# TODO(Phuong): add this test back after it is verified +# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh index fcb066de7..d430e0f41 100644 --- a/tests/jax/multi_process_launch.sh +++ b/tests/jax/multi_process_launch.sh @@ -18,6 +18,14 @@ do CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 & done -CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS | tee stdout_multi_process.txt wait + +RET=0 +if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 +fi + +rm -f stdout_multi_process.txt +exit "$RET" diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 11ff9d061..cecdb3121 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -605,7 +605,12 @@ def test_norm_forward_with_tensor_scaling_fp8( ) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) - @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2]) + @pytest.mark.parametrize( + "out_dtype", + [ + jnp.float8_e4m3fn, + ], + ) def test_norm_forward_with_block_scaling_fp8( self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype ): From 3d76218ee75626f7d749025d8d877468547f78c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Tue, 4 Nov 2025 01:26:39 +0100 Subject: [PATCH 100/141] [PyTorch debug] Fixes to debug tests failures (#2268) * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix: Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L0_pytorch_debug_unittest/test.sh | 41 +++++-- tests/pytorch/debug/run_distributed.py | 7 +- .../test_switch_to_nondebug_mode.yaml | 11 ++ tests/pytorch/debug/test_perf.py | 114 +++++++++--------- .../debug/features/_test_dummy_feature.py | 46 ++++++- 5 files changed, 139 insertions(+), 80 deletions(-) create mode 100644 tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index 9980ccfb0..b6c42109b 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -2,7 +2,19 @@ # # See LICENSE for license information. +function error_exit() { + echo "Error: $1" + exit 1 +} +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} @@ -14,24 +26,27 @@ mkdir -p "$XML_LOG_DIR" # Nvinspect will be disabled if no feature is active. : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} -FAIL=0 - # It is not installed as a requirement, # because it is not available on PyPI. pip uninstall -y nvdlfw-inspect pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git -pip install pytest==8.2.1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pip install pytest==8.2.1 || error_exit "Failed to install pytest" +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 - -exit $FAIL +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index fee2189fa..358841943 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -668,11 +668,12 @@ def _run_test_with_combinations( _init_distributed() test_log_expert_parallel() - for parallel_mode in ["column", "row"]: - for gather_weight in [True, False]: - test_log_distributed(parallel_mode, gather_weight) if fp8_available: + for parallel_mode in ["column", "row"]: + for gather_weight in [True, False]: + test_log_distributed(parallel_mode, gather_weight) + for parallel_mode in ["row", "column"]: test_disable_fp8_layer(parallel_mode) diff --git a/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml b/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml new file mode 100644 index 000000000..224be4618 --- /dev/null +++ b/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml @@ -0,0 +1,11 @@ +test_switch_to_nondebug_mode: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + TestDummyFeature: + enabled: True + inspect_only_once: True + tensors: [weight, activation, gradient, output, wgrad, dgrad] + gemms: [wgrad, dgrad, fprop] + diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index ad40c31c0..c8c9ae3c1 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -6,74 +6,70 @@ import pytest import torch import transformer_engine.pytorch as te -import time import nvdlfw_inspect.api as debug_api from transformer_engine.debug.pytorch.debug_state import TEDebugState -def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs): - debug_api.end_debug() - TEDebugState._reset() - if debug_tools_initialized: - # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS. - # So after 1 warm-up iteration, this layers should work in non-debug mode. - debug_api.initialize( - config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs - ) - - try: - if layer == "linear": - model = torch.nn.Sequential( - te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") - ).cuda() - NUM_ITERS = 1800 - elif layer == "transformer": - model = torch.nn.Sequential( - te.TransformerLayer(1, 1, 1, name="transformer1"), - te.TransformerLayer(1, 1, 1, name="transformer2"), - ).cuda() - NUM_ITERS = 200 - - NUM_INVOCATIONS_PER_ITER = 10 +@pytest.mark.parametrize("use_microbatching", [False, True]) +def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching): + """ + Test that layers switch to non-debug mode when no features are active. - x = torch.randn(1, 1, 1).cuda() + Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None). + The TE should: + 1. Call inspect_tensor_enabled to check if feature is needed + 2. Never call inspect_tensor + 3. Allow layers to switch to non-debug mode for optimal performance, + so that inspect_tensor_enabled is never called again. - y = model(x) - y.sum().backward() - debug_api.step() - torch.cuda.synchronize() + Tests both with and without microbatching to ensure proper behavior in both scenarios. + """ - time_start = time.time() - for i in range(NUM_ITERS): - for _ in range(NUM_INVOCATIONS_PER_ITER): + try: + debug_api.initialize( + config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml", + feature_dirs=feature_dirs, + ) + import transformer_engine.debug.features._test_dummy_feature as dummy_feature + + # Reset counters + dummy_feature._inspect_tensor_enabled_call_count = 0 + dummy_feature._inspect_tensor_call_count = 0 + + model = te.Linear(256, 256, name="test_linear").cuda() + x = torch.randn(8, 256, 256).cuda() + + # Run multiple iterations + for i in range(20): + if use_microbatching: + # Alternate between first and non-first microbatch + is_first_microbatch = i % 2 == 0 + y = model(x, is_first_microbatch=is_first_microbatch) + else: + # Run without specifying is_first_microbatch y = model(x) - y.sum().backward() - if debug_tools_initialized: - debug_api.step() - torch.cuda.synchronize() - time_end = time.time() - - finally: - if debug_tools_initialized: - debug_api.end_debug() - - return time_end - time_start - - -@pytest.mark.parametrize("layer", ["linear", "transformer"]) -def test_cpu_overhead(layer, configs_dir, feature_dirs): - # runs one layer many times on very small tensor - # - gpu time should be negligible, so time should be dominated by cpu time. - # if layers does not invoke any feature in current iteration, - # then it changed into non-debug mode and should not have any non-negligible cpu overhead - # compared to layer without debug tools initialized. - - with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs) - without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs) + y.sum().backward() + debug_api.step() + + # Verify inspect_tensor_enabled was called only once per tensor + # (activation, weight, gradient, output, wgrad, dgrad) + enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count + microbatch_info = "with microbatching" if use_microbatching else "without microbatching" + assert enabled_call_count == 6, ( + f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), " + "but should be called 6 times to check if feature is needed for each tensor " + "(activation, weight, gradient, output, wgrad, dgrad)" + ) - print(f"with_debug_tools: {with_debug_tools} s") - print(f"without_debug_tools: {without_debug_tools} s") + # Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None) + inspect_call_count = dummy_feature._inspect_tensor_call_count + assert inspect_call_count == 0, ( + f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), " + "but should never be called when inspect_tensor_enabled returns (False, None)" + ) - assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin + finally: + debug_api.end_debug() + TEDebugState._reset() diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index c8a31a343..4dee97b70 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -7,19 +7,55 @@ from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.api import TEConfigAPIMapper +# Module-level counters for tracking invocations +# NOTE: These must be accessed via the full module path +# (transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count) +# to ensure the same module instance is used when the feature is loaded by the debug framework +# and when imported by tests. Using just the variable name would create separate instances +# in different import contexts. +_inspect_tensor_enabled_call_count = 0 +_inspect_tensor_call_count = 0 + @Registry.register_feature(namespace="transformer_engine") class TestDummyFeature(TEConfigAPIMapper): """ - This is feature used only in tests. It invokes look_at_tensor_before_process - and does nothing. + This is feature used only in tests. It invokes inspect_tensor and does nothing. If no features are used, then TE layer automatically switches to the non-debug mode. This feature is invoked for each GEMM to prevent this behavior. + + Config options: + - inspect_only_once: if True, return (False, None) from inspect_tensor_enabled to test caching behavior + + Note: This feature always tracks invocations for testing purposes. """ @api_method - def inspect_tensor_enabled(self, *_args, **_kwargs): - """API call used to determine whether to run look_at_tensor_before_process - in the forward pass.""" + def inspect_tensor_enabled(self, config, *_args, **_kwargs): + """API call used to determine whether to run inspect_tensor in the forward pass. + + Always tracks calls for testing purposes. + + Returns: + - If inspect_only_once=True in config: returns (False, None) - check once, never call inspect_tensor + - Otherwise: returns True - feature is always enabled + """ + # Access counter via full module path to ensure we're modifying the same module-level + # variable regardless of import context (debug framework vs test import) + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self + + dummy_feature._inspect_tensor_enabled_call_count += 1 + + inspect_only_once = config.get("inspect_only_once", False) + if inspect_only_once: + return False, None return True + + @api_method + def inspect_tensor(self, _config, *_args, **_kwargs): + """This method does nothing but always tracks invocations for testing.""" + # Access counter via full module path to ensure shared state across import contexts + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self + + dummy_feature._inspect_tensor_call_count += 1 From 77a006352bbe67b6e66fb5a2c5be8ff2d0dd9cde Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:39:41 +0100 Subject: [PATCH 101/141] [PyTorch Debug] Add max_blockwise_dynamic_range stats (#2137) * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/debug/test_log.py | 198 +++++++++++++++++- .../debug/features/log_tensor_stats.py | 77 ++++++- .../debug/features/utils/stats_buffer.py | 6 +- .../debug/features/utils/stats_computation.py | 116 ++++++++++ 4 files changed, 389 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index e9d074821..0f833d41f 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -18,7 +18,11 @@ ) from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.debug.pytorch.debug_state import TEDebugState - +from transformer_engine.debug.features.utils.stats_computation import ( + compute_max_blockwise_dynamic_range, + BlockwiseDynamicRangeStat, +) +import math fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @@ -154,7 +158,7 @@ def test_sanity(feature_dirs): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -def test_numerics(fp8_recipe, feature_dirs): +def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): if not fp8_available: pytest.skip(reason_for_no_fp8) if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling(): @@ -210,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs): assert overflows == pytest.approx(expected.cpu(), abs=1e-4) +LOG_HIGH_PRECISION_CONFIG = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + stats: + - dynamic_range + - max_blockwise_dynamic_range: + block_size: 4 + dims: 1 + - max_blockwise_dynamic_range: + block_size: 4 + dims: 2 + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" + + +@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"]) +def test_log_stats_numerics(feature_dirs, tensor_name): + """Check correctness of dynamic range and max blockwise dynamic range stats. + + Tests different tensor types: + - activation/weight: use both orientations (rowwise + columnwise), takes max + - gradient/dgrad: use single orientation (rowwise only) + """ + log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG + + with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: + # There is 1024 x 1024 tensor with very small epsilon values in almost all elements, + # one row of large value A and three rows of large value B. + epsilon = 1e-10 + A = 1000 + B = 50 + tensor = torch.zeros(1024, 1024).cuda() + epsilon + tensor[0, :] = A + tensor[1:4, :] = B + + debug_api.transformer_engine.inspect_tensor( + layer_name="layer_name", + tensor_name=tensor_name, + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=None, + rowwise_quantized_tensor=None, + columnwise_quantized_tensor=None, + ) + debug_api.step() + + output = read_log(log_dir) + + max_over_orientations = tensor_name in ["activation", "weight"] + max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else "" + + # Track which stats were found to ensure all are present + found_dims_1 = False + found_dims_2 = False + found_dynamic_range = False + + for line in output.splitlines(): + if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line: + max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) + if max_over_orientations: + # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) + expected = math.log2(A) - math.log2(B) + else: + # Rowwise blocks have uniform values -> dynamic_range = 0 + expected = 0 + assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx( + expected, abs=1e-4 + ) + found_dims_1 = True + elif ( + f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line + ): + max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) + # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows + expected = math.log2(A) - math.log2(B) + assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx( + expected, abs=1e-4 + ) + found_dims_2 = True + elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line: + dynamic_range = float(line.split("value=")[1]) + expected = math.log2(A) - math.log2(epsilon) + assert dynamic_range == pytest.approx(expected, abs=1e-4) + found_dynamic_range = True + + # Ensure all expected stats were found in the output + assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output" + assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output" + assert found_dynamic_range, "dynamic_range not found in output" + + @pytest.mark.parametrize("layer", ["linear", "transformer"]) def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): if not fp8_available: @@ -256,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): debug_api.end_debug() TEDebugState._reset() + + +def test_compute_max_blockwise_dynamic_range_direct(): + """Direct unit test for compute_max_blockwise_dynamic_range function. + + Tests the function with various configurations to ensure correct behavior + for different block sizes, dimensions, and orientation settings. + """ + # Create test tensor with uniform rows but mixed columns + # Row 0: all 1000, Row 1-3: all 50, remaining: all 0.01 + epsilon = 0.01 + A = 1000.0 + B = 50.0 + tensor = torch.zeros(1024, 1024).cuda() + epsilon + tensor[0, :] = A + tensor[1:4, :] = B + + # Test 1: dims=1, max_over_orientations=False (rowwise only) + # Rowwise blocks have uniform values -> dynamic_range should be 0 + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + assert result.item() == pytest.approx( + 0.0, abs=1e-4 + ), "Rowwise 1D blocks with uniform values should have dynamic_range=0" + + # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise) + # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + expected = math.log2(A) - math.log2(B) + assert result.item() == pytest.approx(expected, abs=1e-4), ( + f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got" + f" {result.item()}" + ) + + # Test 3: dims=2, block_size=4 (4x4 tiles) + # 2D blocks span multiple rows -> always have mixed values + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + expected = math.log2(A) - math.log2(B) + assert result.item() == pytest.approx(expected, abs=1e-4), ( + f"2D blocks should capture mixed values from different rows, expected {expected}, got" + f" {result.item()}" + ) + + # Test 4: Different block size + # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon] + # So max=A, min=epsilon (not B anymore) + stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + expected = math.log2(A) - math.log2(epsilon) # min is epsilon, not B + assert result.item() == pytest.approx( + expected, abs=1e-4 + ), f"Block size 8 should work correctly, expected {expected}, got {result.item()}" + + # Test 5: Tensor with all uniform values -> dynamic_range should be 0 + uniform_tensor = torch.ones(64, 64).cuda() * 42.0 + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) + result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config) + assert result.item() == pytest.approx( + 0.0, abs=1e-4 + ), "Uniform tensor should have dynamic_range=0" + + # Test 6: 3D tensor flattening validation using 2D/3D comparison + # Create a 4x4 tensor with distinct 2x2 blocks, compute with dims=2, block_size=2 + # Then reshape to 3D and compute again - results should match if flattening is correct + tensor_2d = torch.tensor( + [ + [1.0, 1.0, 10.0, 10.0], + [1.0, 1.0, 10.0, 10.0], + [100.0, 100.0, 1000.0, 1000.0], + [100.0, 100.0, 1000.0, 1000.0], + ] + ).cuda() + + # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100) + stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False) + result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config) + + # Reshape to 3D [2, 2, 4] and compute - should give same result if flattening is correct + tensor_3d = tensor_2d.reshape(2, 2, 4) + result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config) + + assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), ( + "3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got" + f" 2D={result_2d.item()}, 3D={result_3d.item()}" + ) + + print("All direct tests for compute_max_blockwise_dynamic_range passed!") diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index e917cf9a0..ff37e659a 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -4,7 +4,7 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, Optional +from typing import Dict, Optional, List import torch @@ -19,6 +19,10 @@ from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params +from transformer_engine.debug.features.utils.stats_computation import ( + add_max_blockwise_dynamic_range_stats, + BlockwiseDynamicRangeStat, +) @Registry.register_feature(namespace="transformer_engine") @@ -44,7 +48,14 @@ class LogTensorStats(BaseLogTensorStats): - l1_norm - l2_norm - cur_amax – maximal absolute value of a tensor, - - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` + - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)` + - max_blockwise_dynamic_range – Computes the maximum dynamic range `log2(amax) - log2(nonzero_amin)` across all blocks of size block_size within the tensor. + If tensor and its transpose is needed in training, this stat is computed for both orientations and the maximum is returned. + For `dim=1` there are block_size consecutive elements in the block, for `dim=2` the block is block_size x block_size elements tile. + + - block_size: int, default = 32 + - dims: int, default = 1, allowed values are 1 and 2 + tensors/tensors_struct: List[str] list of tensors to log @@ -88,6 +99,60 @@ class LogTensorStats(BaseLogTensorStats): stats: [dynamic_range] """ + def _is_supported_stat(self, stat: str | Dict): + """Returns True if the stat is supported by this feature, False otherwise.""" + if isinstance(stat, dict): + stat_name = list(stat.keys())[0] + if stat_name == "max_blockwise_dynamic_range": + stat_dict = stat[stat_name] + if not isinstance(stat_dict, dict): + return False + # Ensure only supported keys are present + allowed_keys = {"block_size", "dims"} + if any(k not in allowed_keys for k in stat_dict.keys()): + return False + block_size = stat_dict.get("block_size", 32) + dims = stat_dict.get("dims", 1) + # Type and value validation + if not isinstance(block_size, int) or not isinstance(dims, int): + return False + if block_size > 0 and dims in [1, 2]: + return True + return False + return stat in BaseLogTensorStats._get_supported_stats_list(None) | { + "cur_amax", + "dynamic_range", + } + + def _parse_max_blockwise_dynamic_range_stats( + self, stats: List[str | Dict], tensor_name: str + ) -> List[str | BlockwiseDynamicRangeStat]: + """ + Adds all max_blockwise_dynamic_range stats to the stat computation logic. + Changes the types of the stats from Dict to BlockwiseDynamicRangeStat named tuple, + for other stats nothing is changed. + + For example, if the stats is [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}], + it will be changed to [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=True)] + or [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=False)] depending on tensor_name. + + """ + max_over_orientations = tensor_name in ["activation", "weight"] + parsed_stats = [] + for stat in stats: + if isinstance(stat, dict): + block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32) + dims = stat["max_blockwise_dynamic_range"].get("dims", 1) + + # Register stat and return the named tuple + parsed_stat = add_max_blockwise_dynamic_range_stats( + block_size, dims, max_over_orientations + ) + parsed_stats.append(parsed_stat) + else: + parsed_stats.append(stat) + return parsed_stats + def _get_supported_stats_list(self): """Returns stats this feature can log.""" return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"} @@ -147,14 +212,16 @@ def inspect_tensor( ) for stat in config["stats"]: - assert ( - stat in self._get_supported_stats_list() + assert self._is_supported_stat( + stat ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." + stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"], tensor_name) + STATS_BUFFERS.try_add_buffer( layer_name=layer_name, tensor_name=tensor_name, - stats=config["stats"], + stats=stats, options=options, reduction_group=reduction_group, reduce_within_microbatch=reduce_within_microbatch, diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 20236fb95..b5b462f5a 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -130,8 +130,12 @@ def log(self): for stat_name in self.stats_to_log: combiner = STATS[stat_name][1] stat_value = combiner(gathered_helper_stats) + + # Convert stat key to string for logging (uses __str__ for named tuples) + stat_name_str = str(stat_name) + MetricLogger.log_scalar( - f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration + f"{self.layer_name}_{self.tensor_name}_{stat_name_str}", stat_value, self.iteration ) output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = ( stat_value # for debugging purposes diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 2fa6985ac..8c480441c 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -7,12 +7,25 @@ """ import math +from collections import namedtuple + import torch import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.common.recipe import Format +class BlockwiseDynamicRangeStat( + namedtuple("BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"]) +): + """Named tuple representing a blockwise dynamic range statistic configuration.""" + + def __str__(self) -> str: + """Convert to string representation for stat name. Used for logging.""" + suffix = "_max_over_orientations" if self.max_over_orientations else "" + return f"max_blockwise_dynamic_range_block_size_{self.block_size}_dims_{self.dims}{suffix}" + + @torch.compile def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" @@ -26,6 +39,7 @@ def _compute_dynamic_range_top(tensor): return torch.log2(amax) +@torch.compile def _compute_dynamic_range_bottom(tensor): """Computes the log2 of the amin of the tensor""" tensor_abs = tensor.abs() @@ -37,6 +51,76 @@ def _compute_dynamic_range_bottom(tensor): return torch.log2(amin) +def compute_max_blockwise_dynamic_range(tensor, stat_config): + """ + Computes maximum blockwise dynamic range (log2 max/min_nonzero) within blocks. + + Flattens tensor to 2D and computes maximum dynamic range within blocks. If max_over_orientations + is True, computes for both rowwise and columnwise orientations and returns the maximum, + capturing the worst-case scenario regardless of how the tensor is used in GEMM operations. + If False, computes only for rowwise orientation. + + Returns 0 if all blocks are zeros, otherwise computes dynamic range over non-zero blocks. + + Args: + tensor: Input tensor (will be flattened to 2D) + stat_config: BlockwiseDynamicRangeStat named tuple with: + - block_size: Size of blocks (int) + - dims: 1 for 1D blocks (consecutive elements), 2 for 2D blocks (tiles) + - max_over_orientations: If True, compute max over rowwise and columnwise orientations + """ + # Extract parameters from stat_config + block_size = stat_config.block_size + dims = stat_config.dims + max_over_orientations = stat_config.max_over_orientations + + def _compute_for_one_orientation(tensor): + total_numel = tensor.numel() + assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" + + # torch.compile friendly code - standard ** power does not work with jit + total_block_size = block_size * block_size if dims == 2 else block_size + assert ( + total_numel % total_block_size == 0 + ), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." + + tensor = tensor.abs().float() + if dims == 1: + tensor = tensor.reshape(-1, block_size) + per_block_amax = tensor.amax(dim=1) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1) + else: + # We want to have tensor of shape [nr_blocks, block_size, block_size], + # where each block is a block_size x block_size tile of the original tensor. + dim_y = tensor.shape[-1] // block_size + tensor = ( + tensor.reshape(-1, block_size, dim_y, block_size) + .permute(0, 2, 1, 3) + .reshape(-1, block_size, block_size) + ) + per_block_amax = tensor.amax(dim=(1, 2)) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2)) + + # Identify blocks that contain any non-zero element + nonzero_blocks = per_block_amax != 0 + dynamic_range_per_block = torch.where( + nonzero_blocks, + torch.log2(per_block_amax) - torch.log2(per_block_amin), + torch.zeros_like(per_block_amax, dtype=torch.float32), + ) + return dynamic_range_per_block.max() + + # Flatten to 2D + tensor_2d = tensor.reshape(-1, tensor.shape[-1]) + if max_over_orientations: + return max( + _compute_for_one_orientation(tensor_2d), # Rowwise orientation + _compute_for_one_orientation(tensor_2d.transpose(-2, -1)), # Columnwise orientation + ) + return _compute_for_one_orientation(tensor_2d) + + +@torch.compile def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) @@ -45,6 +129,7 @@ def compute_variance(variances, numels, sums): return var +@torch.compile def compute_std(variances, numels, sums): """Computates standard deviation.""" return torch.sqrt(compute_variance(variances, numels, sums)) @@ -316,6 +401,37 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} +def add_max_blockwise_dynamic_range_stats( + block_size: int, dims: int, max_over_orientations: bool = False +): + """Register max_blockwise_X_dynamic_range stats for the recipe. + + Args: + block_size: Size of blocks for computing blockwise dynamic range + dims: 1 for 1D blocks, 2 for 2D blocks + max_over_orientations: Whether to compute max over rowwise and columnwise orientations + + Returns: + BlockwiseDynamicRangeStat named tuple representing this stat (used as the stat key) + """ + # Use named tuple directly as the stat key - this is cleaner than string keys + stat_key = BlockwiseDynamicRangeStat(block_size, dims, max_over_orientations) + + if stat_key in stats_to_num: + return stat_key # already registered + + assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" + stats_to_num[stat_key] = len(stats_to_num) + DEPENDENCIES[stat_key] = {stat_key} + + STATS[stat_key] = ( + lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range(x, _stat_key), + lambda buffers, _stat_key=stat_key: max(_get(buffers, _stat_key)), + ) + + return stat_key + + for _columnwise in [True, False]: for _recipe_name in [ "", # default recipe From b6020e3bce7e0a22c6bf988ddee578943c60821f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 5 Nov 2025 23:49:11 +0100 Subject: [PATCH 102/141] [JAX] Fix bug with pre scale bias (#2300) * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski --- transformer_engine/jax/flax/transformer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 1eafed413..42c945124 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -197,6 +197,7 @@ def __call__( fused_scale_factor = scale_factor if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias + bias = None def apply_swa_mask(original_mask: Array) -> Array: """Apply the sliding window mask to a given mask""" From dcaca2a67ee0c390cb900a4c29168aef6ac198d5 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 5 Nov 2025 16:11:53 -0800 Subject: [PATCH 103/141] [JAX] Try to use pre-downloaded dataset artifacts first (#2345) * Try to use pre-downloaded dataset artifacts first Signed-off-by: Jeremy Berchtold * Set HF_HUB_OFFLINE to disable any network calls to HF when the pre-downloaded dataset is available Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- examples/jax/encoder/common.py | 55 ++++++++++++++++--- .../encoder/test_model_parallel_encoder.py | 4 +- examples/jax/encoder/test_multigpu_encoder.py | 4 +- .../encoder/test_multiprocessing_encoder.py | 4 +- .../jax/encoder/test_single_gpu_encoder.py | 4 +- examples/jax/mnist/test_single_gpu_mnist.py | 4 +- 6 files changed, 57 insertions(+), 18 deletions(-) diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index 9ffcfe57d..819fdf443 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -3,6 +3,9 @@ # See LICENSE for license information. """Shared functions for the encoder tests""" from functools import lru_cache +import os +import pathlib +import zipfile import jax import jax.numpy @@ -120,12 +123,48 @@ def get_quantization_recipe_from_name_string(name: str): raise ValueError(f"Invalid quantization_recipe, got {name}") -def hf_login_if_available(): - """Login to HF hub if available""" - try: - from huggingface_hub import login +@lru_cache(maxsize=None) +def _get_example_artifacts_dir() -> pathlib.Path: + """Path to directory with pre-downloaded datasets""" - login() - except Exception as e: - print(e) - pass + # Check environment variable + path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH") + if path: + return pathlib.Path(path).resolve() + + # Fallback to path in root dir + root_dir = pathlib.Path(__file__).resolve().parent.parent.parent + return root_dir / "artifacts" / "examples" / "jax" + + +def _unpack_cached_dataset(artifacts_dir: pathlib.Path, folder_name: str) -> None: + """Unpack a cached dataset if available""" + dataset_dir = artifacts_dir / folder_name + if not dataset_dir.exists(): + print(f"Cached dataset {folder_name} not found at {dataset_dir}, skipping unpack") + return + + # Disable any HF network calls since the dataset is cached locally + os.environ["HF_HUB_OFFLINE"] = "1" + + for filename in os.listdir(dataset_dir): + filepath = dataset_dir / filename + if not filename.endswith(".zip"): + continue + print(f"Unpacking cached dataset {folder_name} from {filepath}") + + with zipfile.ZipFile(filepath, "r") as zip_ref: + zip_ref.extractall(pathlib.Path.home() / ".cache" / "huggingface") + print( + f"Unpacked cached dataset {folder_name} to" + f" {pathlib.Path.home() / '.cache' / 'huggingface'}" + ) + + +# This is cached so we don't have to unpack datasets multiple times +@lru_cache(maxsize=None) +def unpack_cached_datasets_if_available() -> None: + """Unpack cached datasets if available""" + artifacts_dir = _get_example_artifacts_dir() + _unpack_cached_dataset(artifacts_dir, "mnist") + _unpack_cached_dataset(artifacts_dir, "encoder") diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index c6d867ef9..a3935da9f 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -23,14 +23,14 @@ is_bf16_supported, get_quantization_recipe_from_name_string, assert_params_sufficiently_sharded, - hf_login_if_available, + unpack_cached_datasets_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode -hf_login_if_available() +unpack_cached_datasets_if_available() DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index 1004dd2dd..80a2b043c 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -22,14 +22,14 @@ from common import ( is_bf16_supported, get_quantization_recipe_from_name_string, - hf_login_if_available, + unpack_cached_datasets_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode -hf_login_if_available() +unpack_cached_datasets_if_available() DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 9605adf77..4d2141116 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -27,13 +27,13 @@ is_mxfp8_supported, is_nvfp4_supported, get_quantization_recipe_from_name_string, - hf_login_if_available, + unpack_cached_datasets_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -hf_login_if_available() +unpack_cached_datasets_if_available() os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 81f2d6c74..7835b08b2 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -19,13 +19,13 @@ from common import ( is_bf16_supported, get_quantization_recipe_from_name_string, - hf_login_if_available, + unpack_cached_datasets_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode -hf_login_if_available() +unpack_cached_datasets_if_available() PARAMS_KEY = "params" DROPOUT_KEY = "dropout" diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 2e9d56e93..62f7954e0 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -25,10 +25,10 @@ from encoder.common import ( is_bf16_supported, get_quantization_recipe_from_name_string, - hf_login_if_available, + unpack_cached_datasets_if_available, ) -hf_login_if_available() +unpack_cached_datasets_if_available() IMAGE_H = 28 IMAGE_W = 28 From f3b97c26b58212b96b0580bc884ddc263a1ea1c2 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Thu, 6 Nov 2025 11:47:17 -0800 Subject: [PATCH 104/141] Fix out of bounds access in the FP4 dequantize kernel (#2346) Signed-off-by: Przemek Tredak --- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 4 +++ .../tensor/storage/nvfp4_tensor_storage.py | 33 ++----------------- 2 files changed, 7 insertions(+), 30 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index bf7b535be..5307cad37 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -39,6 +39,10 @@ __global__ void __launch_bounds__(512) const size_t x = thread_idx % M; const size_t y = thread_idx / M; + if (y >= N) { + return; + } + union fp4vec { uint64_t vec; fp4e2m1x4 small_vec[4]; diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 67543a8e2..04ab092ee 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -13,12 +13,12 @@ import torch -# import transformer_engine_torch as tex +import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from ...quantized_tensor import QuantizedTensorStorage, Quantizer -# from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...constants import TE_DType as torch_to_transformer_engine_dtype from ...utils import _empty_tensor @@ -45,34 +45,7 @@ def forward( # Dequantize row-wise data if tensor._rowwise_data is not None: - ### TODO(tmoon): Debug dequantize kernel and remove unfused impl - # return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) - - # Tensor properties - shape = list(tensor._rowwise_data.size()) - shape[-1] *= 2 - device = tensor._rowwise_data.device - - # Convert FP4E2M1 values to FP32 - data = tensor._rowwise_data.view(torch.uint8).to(torch.int32) - data = torch.stack((data & 0x0F, data >> 4), dim=-1).reshape(shape) - data = _fp4_e2m1_vals(device, dtype=torch.float32)[data] - data = data.to(torch.float32).contiguous() - - # Convert FP8E4M3 block scales to FP32 - block_scales = tensor._rowwise_scale_inv - block_scales = block_scales.reshape(-1, block_scales.size(-1)) - block_scales = block_scales[: math.prod(shape[:-1]), : shape[-1] // 16] - block_scales = block_scales.view(torch.float8_e4m3fn).to(torch.float32) - - # Convert amax to FP32 tensor scale - tensor_scale = tensor._amax_rowwise / (6.0 * 448.0) # Scale by FP4E2M1 and FP8E4M3 max - - # Apply scales - block_data = data.view(-1, 16) - block_data *= tensor_scale.view(()) * block_scales.reshape(-1, 1) - - return data.to(dtype) + return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) if tensor._columnwise_data is not None: raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!") From b14a3b62cd58ab800d55daf42337b36514e943c7 Mon Sep 17 00:00:00 2001 From: Kunlun Li <94586211+kunlunl@users.noreply.github.com> Date: Fri, 7 Nov 2025 05:19:46 +0800 Subject: [PATCH 105/141] Make FP8 weights compatible with older MCore version (#2342) * Make cast_master_weights_to_fp8 compatible with older MCore version Signed-off-by: kunlunl * Rename keep_columnwise to manual_post_all_gather_processing & Optimize unit test Signed-off-by: kunlunl * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove redundant _test_mini_optimizer() Signed-off-by: kunlunl --------- Signed-off-by: kunlunl Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../run_cast_master_weights_to_fp8.py | 690 ---------------- .../test_cast_master_weights_to_fp8.py | 751 +++++++++++++++++- transformer_engine/pytorch/tensor/utils.py | 56 +- 3 files changed, 771 insertions(+), 726 deletions(-) delete mode 100644 tests/pytorch/distributed/run_cast_master_weights_to_fp8.py diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py deleted file mode 100644 index 2f11a24ee..000000000 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ /dev/null @@ -1,690 +0,0 @@ -#!/usr/bin/python3 - -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import argparse -import datetime -import os -import sys - -import torch -from torch import nn -import torch.distributed as dist - -from transformer_engine.common.recipe import ( - DelayedScaling, - Float8CurrentScaling, - Float8BlockScaling, - Format, - Recipe, -) -import transformer_engine.pytorch as te -from transformer_engine.pytorch import ( - QuantizedTensor, - Float8Tensor, - Float8BlockwiseQTensor, -) -from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 -from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data - - -def _get_raw_data(quantized_tensor): - """Get the underlying data of a quantized tensor, used in zero-1 optimizer""" - if isinstance(quantized_tensor, Float8Tensor): - assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" - assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" - return quantized_tensor._data - elif isinstance(quantized_tensor, Float8BlockwiseQTensor): - assert hasattr( - quantized_tensor, "_rowwise_data" - ), "Float8BlockwiseQTensor does not have _rowwise_data attribute" - assert ( - quantized_tensor._rowwise_data.dtype == torch.uint8 - ), "Float8BlockwiseQTensor _rowwise_data must be uint8" - return quantized_tensor._rowwise_data - else: - raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") - - -class MiniZero_1: - """A mini zero-1 optimizer implementation, just used for this test""" - - def __init__(self, weights, lr, dp_group): - self.rank = dist.get_rank(dp_group) - self.world_size = dist.get_world_size(dp_group) - - self.weights = weights - self.lr = lr - self.dp_group = dp_group - - # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer - self.offsets = [0] - for weight in self.weights: - self.offsets.append(self.offsets[-1] + weight.numel()) - - # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may - # not be the end range of the last weight. - if self.offsets[-1] % self.world_size != 0: - self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size - - self.master_weights = [] - # The start offset of the master weight in the weight - self.start_offsets = [] - # The overlapping area of the weight and this rank's local buffer - self.overlapping_areas = [] - - # The start and end of this rank's local buffer in the global buffer - rank_start = self.offsets[-1] // self.world_size * self.rank - rank_end = rank_start + self.offsets[-1] // self.world_size - - for weight, offset in zip(self.weights, self.offsets[:-1]): - if offset >= rank_end or (offset + weight.numel()) <= rank_start: - # This weight is not in this rank's local buffer - master_weight = None - start_offset = None - overlapping_area = None - else: - overlapping_start = max(rank_start, offset) - overlapping_end = min(rank_end, offset + weight.numel()) - length = overlapping_end - overlapping_start - start_offset = overlapping_start - offset - if isinstance(weight, QuantizedTensor): - # If weight is a FP8 tensor, we need to use the original high precision version - # to initialize the master weight. - high_precision_init_val = weight.get_high_precision_init_val().view(-1) - master_weight = high_precision_init_val.to(weight.device).float()[ - start_offset : start_offset + length - ] - else: - master_weight = ( - weight.detach().view(-1).float()[start_offset : start_offset + length] - ) - overlapping_area = (overlapping_start, overlapping_end) - self.master_weights.append(master_weight) - self.start_offsets.append(start_offset) - self.overlapping_areas.append(overlapping_area) - - # Create global buffer for grads reduce-scatter - self.grad_buffer = torch.empty( - [self.offsets[-1]], dtype=torch.float32, device=weights[0].device - ) - self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end] - - # Create global buffer for weights all-gather - if isinstance(self.weights[0], QuantizedTensor): - weight_buffer_dtype = torch.uint8 - else: - weight_buffer_dtype = weights[0].dtype - self.weight_buffer = torch.empty( - [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device - ) - self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] - - def step(self): - # ----------------------------------------------------------------------------------------- - # Step 1: Copy grads to the grad buffer - # ----------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - self.grad_buffer[start:end].copy_(weight.main_grad.view(-1)) - - # ----------------------------------------------------------------------------------------- - # Step 2: Grads reduce-scatter - # ----------------------------------------------------------------------------------------- - # Don't use reduce_scatter directly to explicitly control the reduce order. - # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG, - # group=self.dp_group) - buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)] - dist.all_gather(buffers, self.grad_buffer, group=self.dp_group) - for i in range(1, self.world_size): - buffers[0] += buffers[i] - rank_start = self.offsets[-1] // self.world_size * self.rank - rank_end = rank_start + self.offsets[-1] // self.world_size - self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end]) - self.grad_buffer_slice /= self.world_size - - # ----------------------------------------------------------------------------------------- - # Step 3: Update master weights - # ----------------------------------------------------------------------------------------- - for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas): - if master_weight is None: - # This weight's master weight is in other rank. - continue - grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]] - master_weight -= grad * self.lr - - # ----------------------------------------------------------------------------------------- - # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight - # ----------------------------------------------------------------------------------------- - if isinstance(self.weights[0], QuantizedTensor): - # FP8 weights case - for i in range(1, len(self.weights)): - assert isinstance(self.weights[i], QuantizedTensor) - cast_master_weights_to_fp8( - self.weights, self.master_weights, self.start_offsets, self.dp_group - ) - else: - # BF16 weights case - for weight, master_weight, start_offset in zip( - self.weights, self.master_weights, self.start_offsets - ): - if master_weight is None: - continue - start = start_offset - end = start_offset + master_weight.numel() - weight.data.view(-1)[start:end].copy_(master_weight) - - # ----------------------------------------------------------------------------------------- - # Step 5: Copy the updated weights (not all weights) to the weight buffer - # ----------------------------------------------------------------------------------------- - for i in range(len(self.weights)): - master_weight = self.master_weights[i] - if master_weight is None: - continue - start_offset = self.start_offsets[i] - if isinstance(self.weights[i], QuantizedTensor): - weight = _get_raw_data(self.weights[i]) - else: - weight = self.weights[i] - weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] - overlapping_start, overlapping_end = self.overlapping_areas[i] - self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) - - # ----------------------------------------------------------------------------------------- - # Step 6: Weight all-gather (FP8 or BF16) - # ----------------------------------------------------------------------------------------- - dist.all_gather_into_tensor( - self.weight_buffer, self.weight_buffer_slice, group=self.dp_group - ) - - # ----------------------------------------------------------------------------------------- - # Step 7: Copy the gathered weights from weight buffer to the actual weights - # ----------------------------------------------------------------------------------------- - quantized_weights = [] - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - if isinstance(weight, QuantizedTensor): - quantized_weights.append(weight) - weight = _get_raw_data(weight) - weight.view(-1).data.copy_(self.weight_buffer[start:end]) - post_all_gather_processing(quantized_weights) - - -class MiniOptimizer: - - def __init__(self, weights, lr, dp_group): - self.world_size = dist.get_world_size(dp_group) - - self.weights = weights - self.lr = lr - self.dp_group = dp_group - - master_weights = [] - for weight in self.weights: - master_weights.append(weight.detach().float()) - self.master_weights = master_weights - - def step(self): - for weight, master_weight in zip(self.weights, self.master_weights): - main_grad = weight.main_grad - - # Don't use all-reduce directly to explicitly control the reduce order. - # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group) - buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)] - dist.all_gather(buffers, main_grad, group=self.dp_group) - for i in range(1, self.world_size): - buffers[0] += buffers[i] - main_grad.copy_(buffers[0]) - main_grad /= self.world_size - - master_weight -= main_grad * self.lr - weight.data.copy_(master_weight) - - -class MiniFSDP: - def __init__(self, weights, lr, dp_group): - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - self.weights = weights - self.lr = lr - self.dp_group = dp_group - - # Flatten the weights and pad to align with world size - if isinstance(weights[0], QuantizedTensor): - raw_data_list = [_get_raw_data(w).view(-1) for w in weights] - else: - raw_data_list = [w.view(-1) for w in weights] - self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) - - # Split flattened weights into shards - self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] - self.local_main_grad_shard = torch.zeros_like( - self.local_weight_shard, dtype=torch.float32, device="cuda" - ) - shard_size = self.flatten_weight.size(0) // world_size - - # Map original tensors to flattened indices - tensor_indices = [] - cumulative_length = 0 - for tensor in raw_data_list: - length = tensor.size(0) - tensor_indices.append((cumulative_length, cumulative_length + length)) - cumulative_length += length - - # Build shard index mappings - self.weight_indices = [] - self.shard_indices = [] - for idx, (start, end) in enumerate(tensor_indices): - shard_start = rank * shard_size - shard_end = shard_start + shard_size - adjusted_end = min(shard_end, original_length) - - if start <= adjusted_end and end >= shard_start: - start_idx = max(start, shard_start) - end_idx = min(end, adjusted_end) - self.weight_indices.append((start_idx - start, end_idx - start)) - self.shard_indices.append((start_idx - shard_start, end_idx - shard_start)) - else: - self.weight_indices.append((None, None)) - self.shard_indices.append((None, None)) - - if isinstance(weights[idx], QuantizedTensor): - replace_raw_data( - weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) - ) - else: - weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) - - # Initialize local model weights and high-precision master weights - self.local_weights = [] - self.master_weights = [] - for i, weight in enumerate(self.weights): - weight_start, weight_end = self.weight_indices[i] - shard_start, shard_end = self.shard_indices[i] - if shard_start is not None and shard_end is not None: - local_weight_shard = self.local_weight_shard[shard_start:shard_end] - self.local_weights.append(local_weight_shard) - - if isinstance(weight, QuantizedTensor): - high_precision_init_val = weight.get_high_precision_init_val().view(-1) - master_weight_shard = high_precision_init_val.to(weight.device).float()[ - weight_start:weight_end - ] - else: - master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end] - self.master_weights.append(master_weight_shard) - else: - self.local_weights.append(None) - self.master_weights.append(None) - setattr( - weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") - ) - - def _flatten_tensors_with_pad(self, tensors): - """ - Flatten the list of tensors and pad them to align with the world size. - - Args: - tensors (list): List of tensors to flatten. - - Returns: - tuple: Flattened tensor and its original length before padding. - """ - world_size = dist.get_world_size(self.dp_group) - - flatten_tensor = torch.cat(tensors) - original_length = flatten_tensor.size(0) - - padding_needed = (world_size - original_length % world_size) % world_size - if padding_needed > 0: - zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda") - flatten_tensor = torch.cat([flatten_tensor, zeros]) - - return flatten_tensor, original_length - - def zero_grad(self): - for weight in self.weights: - weight.grad = None - weight.main_grad.zero_() - - def step(self): - """ - Perform an optimization step for the distributed sharded model. - - This method includes: - 1. Gradient reduce-scatter: Synchronize gradients across all processes. - 2. Master weight update: Update high-precision master weights using local gradients. - 3. Precision casting: Cast updated master weights to FP8 or BF16 precision. - 4. Weight synchronization: All-gather updated weights across all processes. - - Returns: - None - """ - # Step 1: Reduce-scatter the gradients - main_grad_buffer, _ = self._flatten_tensors_with_pad( - [weight.main_grad.view(-1) for weight in self.weights] - ) - dist.reduce_scatter_tensor( - self.local_main_grad_shard, main_grad_buffer, group=self.dp_group - ) - self.local_main_grad_shard /= dist.get_world_size(self.dp_group) - - # Step 2: Update the master weights - for weight, master_weight, (shard_start, shard_end) in zip( - self.weights, self.master_weights, self.shard_indices - ): - if master_weight is None: - continue - - # Extract the local gradient shard for this weight - grad = self.local_main_grad_shard[shard_start:shard_end] - - # Update the master weight using gradient descent - master_weight -= grad * self.lr - - # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], QuantizedTensor): - local_weights = [] - for local_weight in self.local_weights: - if local_weight is None: - local_weights.append(None) - continue - - local_weights.append(local_weight) - - cast_master_weights_to_fp8( - self.weights, - self.master_weights, - [idx[0] for idx in self.weight_indices], - self.dp_group, - local_weights, - ) - else: - for weight, master_weight in zip(self.local_weights, self.master_weights): - if master_weight is None: - continue - - # Copy updated master weights to local weights - weight.data.copy_(master_weight) - - # Step 4: All-gather updated weights across processes - dist.all_gather_into_tensor( - self.flatten_weight, self.local_weight_shard, group=self.dp_group - ) - quantized_weights = [] - for weight in self.weights: - if isinstance(weight, QuantizedTensor): - quantized_weights.append(weight) - post_all_gather_processing(quantized_weights) - - -def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - # Configuration constants - NUM_STEPS = 100 - SEED = 12345 - - torch.manual_seed(SEED) - torch.cuda.manual_seed(SEED) - - mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] - mock_group = mock_groups[rank] - - linear_kwargs = { - "params_dtype": torch.bfloat16, - "bias": False, - "fuse_wgrad_accumulation": True, - } - - # Create model with FP8 weights - with te.quantized_model_init( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - preserve_high_precision_init_val=True, - ): - model_fp8 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Create model with BF16 weights - model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Make sure the BF16 model and FP8 model have the same initial weights - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - high_precision_init_val = w_fp8.get_high_precision_init_val() - w.data.copy_(high_precision_init_val) - - optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group) - optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) - - for _ in range(100): - optimizer_fp8.zero_grad() - optimizer.zero_grad() - - inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) - ] - # Choose based on rank to make sure the inputs of different ranks are different. - x = inputs[rank] - - with te.autocast( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - amax_reduction_group=mock_group, - ): - y_fp8 = model_fp8(x) - - with te.autocast( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - amax_reduction_group=mock_group, - ): - y = model(x) - - targets = [torch.randn_like(y) for _ in range(world_size)] - # Choose based on rank to make sure the targets of different ranks are different. - target = targets[rank] - loss_fp8 = nn.MSELoss()(y_fp8, target) - loss = nn.MSELoss()(y, target) - - loss_fp8.backward() - loss.backward() - - optimizer_fp8.step() - optimizer.step() - - torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) - - -def _test_mini_optimizer(dp_group): - """Make sure the implementation of MiniZero_1 and MiniFSDP is correct""" - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - torch.manual_seed(12345) - torch.cuda.manual_seed(12345) - - weights = [ - torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"), - torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"), - torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"), - ] - - weights_1 = weights - weights_2 = [weight.clone() for weight in weights] - weights_3 = [weight.clone() for weight in weights] - - lr = 1.0 - optimizer_1 = MiniZero_1(weights_1, lr, dp_group) - optimizer_2 = MiniOptimizer(weights_2, lr, dp_group) - optimizer_3 = MiniFSDP(weights_3, lr, dp_group) - - for _ in range(100): - for w1, w2, w3 in zip(weights_1, weights_2, weights_3): - main_grads = [ - torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size) - ] - # Choose based on rank to make sure the grads of different ranks are different. - main_grad = main_grads[rank] - w1.main_grad = main_grad - w2.main_grad = main_grad - w3.main_grad = main_grad - - optimizer_1.step() - optimizer_2.step() - optimizer_3.step() - - for w1, w2 in zip(weights_1, weights_2): - torch.testing.assert_close(w1, w2, atol=0, rtol=0) - for w1, w3 in zip(weights_1, weights_3): - torch.testing.assert_close(w1, w3, atol=0, rtol=0) - - -def quantization_recipe(quantization) -> Recipe: - """Quantization recipe setup""" - fp8_format = Format.HYBRID - if quantization == "fp8": - return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - elif quantization == "fp8_cs": - return Float8CurrentScaling(fp8_format=fp8_format) - elif quantization == "fp8_block": - return Float8BlockScaling(fp8_format=fp8_format) - else: - raise ValueError(f"Unsupported quantization: {quantization}") - - -def _test_cast_master_weights_to_fp8(quantization, dp_group): - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - torch.manual_seed(12345) - torch.cuda.manual_seed(12345) - - mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] - mock_group = mock_groups[rank] - - linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} - - # Create model with FP8 weights - with te.quantized_model_init( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - preserve_high_precision_init_val=True, - ): - model_fp8 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Create model with BF16 weights - model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Make sure the BF16 model and FP8 model have the same initial weights - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - high_precision_init_val = w_fp8.get_high_precision_init_val() - w.data.copy_(high_precision_init_val) - - # Allocate main_grads for each weight - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda") - w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") - - optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group) - optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) - - for i in range(100): - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - w_fp8.main_grad.zero_() - w.main_grad.zero_() - - inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) - ] - # Choose based on rank to make sure the inputs of different ranks are different. - x = inputs[rank] - - with te.autocast( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - amax_reduction_group=mock_group, - ): - y_fp8 = model_fp8(x) - - with te.autocast( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - amax_reduction_group=mock_group, - ): - y = model(x) - - targets = [torch.randn_like(y) for _ in range(world_size)] - # Choose based on rank to make sure the targets of different ranks are different. - target = targets[rank] - loss_fp8 = nn.MSELoss()(y_fp8, target) - loss = nn.MSELoss()(y, target) - - loss_fp8.backward() - loss.backward() - - optimizer_fp8.step() - optimizer.step() - - torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) - - -def main(argv=None, namespace=None): - WORLD_RANK = int(os.getenv("RANK", "0")) - WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - - assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node - assert LOCAL_SIZE <= torch.cuda.device_count() - dist_init_kwargs = { - "backend": "nccl", - "rank": WORLD_RANK, - "world_size": WORLD_SIZE, - "timeout": datetime.timedelta(seconds=30), - } - dist_init_kwargs["init_method"] = "env://" - dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") - assert dist.is_nccl_available() - torch.cuda.set_device(LOCAL_RANK) - dist.init_process_group(**dist_init_kwargs) - - parser = argparse.ArgumentParser() - parser.add_argument( - "--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"] - ) - args = parser.parse_args(argv, namespace) - - dp_group = dist.new_group(backend="nccl") - _test_mini_optimizer(dp_group) - _test_cast_master_weights_to_fp8(args.quantization, dp_group) - _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) - - dist.destroy_process_group() - return 0 - - -if __name__ == "__main__": - - sys.exit(main()) diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py index 5bf46b8d5..0ff98e6cb 100644 --- a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py +++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py @@ -2,39 +2,744 @@ # # See LICENSE for license information. +import argparse +import datetime import os import subprocess -from pathlib import Path +import sys +import pathlib import pytest import torch -from transformer_engine.pytorch import is_fp8_available, is_fp8_block_scaling_available +from torch import nn +import torch.distributed as dist +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Float8BlockScaling, + Format, + Recipe, +) +import transformer_engine.pytorch as te +from transformer_engine.pytorch import ( + is_fp8_available, + is_fp8_block_scaling_available, + QuantizedTensor, + Float8Tensor, + Float8BlockwiseQTensor, +) +from transformer_engine.pytorch.tensor import cast_master_weights_to_fp8 +from transformer_engine.pytorch.tensor.utils import post_all_gather_processing, replace_raw_data -if torch.cuda.device_count() < 2: - pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( - return_reason=True -) +def _get_quantization_recipe(quantization) -> Recipe: + """Quantization recipe setup""" + fp8_format = Format.HYBRID + if quantization == "fp8": + return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") + elif quantization == "fp8_cs": + return Float8CurrentScaling(fp8_format=fp8_format) + elif quantization == "fp8_block": + return Float8BlockScaling(fp8_format=fp8_format) + else: + raise ValueError(f"Unsupported quantization: {quantization}") + + +def _get_raw_data(quantized_tensor): + """Get the underlying data of a quantized tensor, used in zero-1 optimizer""" + if isinstance(quantized_tensor, Float8Tensor): + assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" + assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" + return quantized_tensor._data + elif isinstance(quantized_tensor, Float8BlockwiseQTensor): + assert hasattr( + quantized_tensor, "_rowwise_data" + ), "Float8BlockwiseQTensor does not have _rowwise_data attribute" + assert ( + quantized_tensor._rowwise_data.dtype == torch.uint8 + ), "Float8BlockwiseQTensor _rowwise_data must be uint8" + return quantized_tensor._rowwise_data + else: + raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") + + +class MiniOptimizer: + + def __init__(self, weights, lr, dp_group): + self.world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + + master_weights = [] + for weight in self.weights: + master_weights.append(weight.detach().float()) + self.master_weights = master_weights + + def step(self): + for weight, master_weight in zip(self.weights, self.master_weights): + main_grad = weight.main_grad + + # Don't use all-reduce directly to explicitly control the reduce order. + # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group) + buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)] + dist.all_gather(buffers, main_grad, group=self.dp_group) + for i in range(1, self.world_size): + buffers[0] += buffers[i] + main_grad.copy_(buffers[0]) + main_grad /= self.world_size + + master_weight -= main_grad * self.lr + weight.data.copy_(master_weight) + + +class MiniZero_1: + """A mini zero-1 optimizer implementation, just used for this test""" + + def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False): + self.rank = dist.get_rank(dp_group) + self.world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + self.manual_post_all_gather_processing = manual_post_all_gather_processing + + # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer + self.offsets = [0] + for weight in self.weights: + self.offsets.append(self.offsets[-1] + weight.numel()) + + # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may + # not be the end range of the last weight. + if self.offsets[-1] % self.world_size != 0: + self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size + + self.master_weights = [] + # The start offset of the master weight in the weight + self.start_offsets = [] + # The overlapping area of the weight and this rank's local buffer + self.overlapping_areas = [] + + # The start and end of this rank's local buffer in the global buffer + rank_start = self.offsets[-1] // self.world_size * self.rank + rank_end = rank_start + self.offsets[-1] // self.world_size + + for weight, offset in zip(self.weights, self.offsets[:-1]): + if offset >= rank_end or (offset + weight.numel()) <= rank_start: + # This weight is not in this rank's local buffer + master_weight = None + start_offset = None + overlapping_area = None + else: + overlapping_start = max(rank_start, offset) + overlapping_end = min(rank_end, offset + weight.numel()) + length = overlapping_end - overlapping_start + start_offset = overlapping_start - offset + if isinstance(weight, QuantizedTensor): + # If weight is a FP8 tensor, we need to use the original high precision version + # to initialize the master weight. + high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight = high_precision_init_val.to(weight.device).float()[ + start_offset : start_offset + length + ] + else: + master_weight = ( + weight.detach().view(-1).float()[start_offset : start_offset + length] + ) + overlapping_area = (overlapping_start, overlapping_end) + self.master_weights.append(master_weight) + self.start_offsets.append(start_offset) + self.overlapping_areas.append(overlapping_area) + + # Create global buffer for grads reduce-scatter + self.grad_buffer = torch.empty( + [self.offsets[-1]], dtype=torch.float32, device=weights[0].device + ) + self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end] + + # Create global buffer for weights all-gather + if isinstance(self.weights[0], QuantizedTensor): + weight_buffer_dtype = torch.uint8 + else: + weight_buffer_dtype = weights[0].dtype + self.weight_buffer = torch.empty( + [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device + ) + self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] + + def step(self): + # ----------------------------------------------------------------------------------------- + # Step 1: Copy grads to the grad buffer + # ----------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + self.grad_buffer[start:end].copy_(weight.main_grad.view(-1)) + + # ----------------------------------------------------------------------------------------- + # Step 2: Grads reduce-scatter + # ----------------------------------------------------------------------------------------- + # Don't use reduce_scatter directly to explicitly control the reduce order. + # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG, + # group=self.dp_group) + buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)] + dist.all_gather(buffers, self.grad_buffer, group=self.dp_group) + for i in range(1, self.world_size): + buffers[0] += buffers[i] + rank_start = self.offsets[-1] // self.world_size * self.rank + rank_end = rank_start + self.offsets[-1] // self.world_size + self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end]) + self.grad_buffer_slice /= self.world_size + + # ----------------------------------------------------------------------------------------- + # Step 3: Update master weights + # ----------------------------------------------------------------------------------------- + for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas): + if master_weight is None: + # This weight's master weight is in other rank. + continue + grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]] + master_weight -= grad * self.lr + + # ----------------------------------------------------------------------------------------- + # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight + # ----------------------------------------------------------------------------------------- + if isinstance(self.weights[0], QuantizedTensor): + # FP8 weights case + for i in range(1, len(self.weights)): + assert isinstance(self.weights[i], QuantizedTensor) + cast_master_weights_to_fp8( + self.weights, + self.master_weights, + self.start_offsets, + self.dp_group, + manual_post_all_gather_processing=self.manual_post_all_gather_processing, + ) + else: + # BF16 weights case + for weight, master_weight, start_offset in zip( + self.weights, self.master_weights, self.start_offsets + ): + if master_weight is None: + continue + start = start_offset + end = start_offset + master_weight.numel() + weight.data.view(-1)[start:end].copy_(master_weight) + + # ----------------------------------------------------------------------------------------- + # Step 5: Copy the updated weights (not all weights) to the weight buffer + # ----------------------------------------------------------------------------------------- + for i in range(len(self.weights)): + master_weight = self.master_weights[i] + if master_weight is None: + continue + start_offset = self.start_offsets[i] + if isinstance(self.weights[i], QuantizedTensor): + weight = _get_raw_data(self.weights[i]) + else: + weight = self.weights[i] + weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] + overlapping_start, overlapping_end = self.overlapping_areas[i] + self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) + + # ----------------------------------------------------------------------------------------- + # Step 6: Weight all-gather (FP8 or BF16) + # ----------------------------------------------------------------------------------------- + dist.all_gather_into_tensor( + self.weight_buffer, self.weight_buffer_slice, group=self.dp_group + ) + + # ----------------------------------------------------------------------------------------- + # Step 7: Copy the gathered weights from weight buffer to the actual weights + # ----------------------------------------------------------------------------------------- + for weight, offset in zip(self.weights, self.offsets[:-1]): + start = offset + end = offset + weight.numel() + if isinstance(weight, QuantizedTensor): + weight = _get_raw_data(weight) + weight.view(-1).data.copy_(self.weight_buffer[start:end]) + + if self.manual_post_all_gather_processing: + quantized_weights = [ + weight for weight in self.weights if isinstance(weight, QuantizedTensor) + ] + post_all_gather_processing(quantized_weights) + + +class MiniFSDP: + def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + self.weights = weights + self.lr = lr + self.dp_group = dp_group + self.manual_post_all_gather_processing = manual_post_all_gather_processing + + # Flatten the weights and pad to align with world size + if isinstance(weights[0], QuantizedTensor): + raw_data_list = [_get_raw_data(w).view(-1) for w in weights] + else: + raw_data_list = [w.view(-1) for w in weights] + self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) + + # Split flattened weights into shards + self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] + self.local_main_grad_shard = torch.zeros_like( + self.local_weight_shard, dtype=torch.float32, device="cuda" + ) + shard_size = self.flatten_weight.size(0) // world_size + + # Map original tensors to flattened indices + tensor_indices = [] + cumulative_length = 0 + for tensor in raw_data_list: + length = tensor.size(0) + tensor_indices.append((cumulative_length, cumulative_length + length)) + cumulative_length += length + + # Build shard index mappings + self.weight_indices = [] + self.shard_indices = [] + for idx, (start, end) in enumerate(tensor_indices): + shard_start = rank * shard_size + shard_end = shard_start + shard_size + adjusted_end = min(shard_end, original_length) + + if start <= adjusted_end and end >= shard_start: + start_idx = max(start, shard_start) + end_idx = min(end, adjusted_end) + self.weight_indices.append((start_idx - start, end_idx - start)) + self.shard_indices.append((start_idx - shard_start, end_idx - shard_start)) + else: + self.weight_indices.append((None, None)) + self.shard_indices.append((None, None)) + + if isinstance(weights[idx], QuantizedTensor): + replace_raw_data( + weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) + ) + else: + weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) + + # Initialize local model weights and high-precision master weights + self.local_weights = [] + self.master_weights = [] + for i, weight in enumerate(self.weights): + weight_start, weight_end = self.weight_indices[i] + shard_start, shard_end = self.shard_indices[i] + if shard_start is not None and shard_end is not None: + local_weight_shard = self.local_weight_shard[shard_start:shard_end] + self.local_weights.append(local_weight_shard) + + if isinstance(weight, QuantizedTensor): + high_precision_init_val = weight.get_high_precision_init_val().view(-1) + master_weight_shard = high_precision_init_val.to(weight.device).float()[ + weight_start:weight_end + ] + else: + master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end] + self.master_weights.append(master_weight_shard) + else: + self.local_weights.append(None) + self.master_weights.append(None) + setattr( + weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") + ) + + def _flatten_tensors_with_pad(self, tensors): + """ + Flatten the list of tensors and pad them to align with the world size. + + Args: + tensors (list): List of tensors to flatten. + + Returns: + tuple: Flattened tensor and its original length before padding. + """ + world_size = dist.get_world_size(self.dp_group) + + flatten_tensor = torch.cat(tensors) + original_length = flatten_tensor.size(0) + + padding_needed = (world_size - original_length % world_size) % world_size + if padding_needed > 0: + zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda") + flatten_tensor = torch.cat([flatten_tensor, zeros]) + + return flatten_tensor, original_length + + def zero_grad(self): + for weight in self.weights: + weight.grad = None + weight.main_grad.zero_() + + def step(self): + """ + Perform an optimization step for the distributed sharded model. + + This method includes: + 1. Gradient reduce-scatter: Synchronize gradients across all processes. + 2. Master weight update: Update high-precision master weights using local gradients. + 3. Precision casting: Cast updated master weights to FP8 or BF16 precision. + 4. Weight synchronization: All-gather updated weights across all processes. + + Returns: + None + """ + # Step 1: Reduce-scatter the gradients + main_grad_buffer, _ = self._flatten_tensors_with_pad( + [weight.main_grad.view(-1) for weight in self.weights] + ) + dist.reduce_scatter_tensor( + self.local_main_grad_shard, main_grad_buffer, group=self.dp_group + ) + self.local_main_grad_shard /= dist.get_world_size(self.dp_group) + + # Step 2: Update the master weights + for weight, master_weight, (shard_start, shard_end) in zip( + self.weights, self.master_weights, self.shard_indices + ): + if master_weight is None: + continue + + # Extract the local gradient shard for this weight + grad = self.local_main_grad_shard[shard_start:shard_end] + + # Update the master weight using gradient descent + master_weight -= grad * self.lr + + # Step 3: Cast master weights to FP8 or BF16 precision + if isinstance(self.weights[0], QuantizedTensor): + local_weights = [] + for local_weight in self.local_weights: + if local_weight is None: + local_weights.append(None) + continue + + local_weights.append(local_weight) + + cast_master_weights_to_fp8( + self.weights, + self.master_weights, + [idx[0] for idx in self.weight_indices], + self.dp_group, + local_weights, + manual_post_all_gather_processing=self.manual_post_all_gather_processing, + ) + else: + for weight, master_weight in zip(self.local_weights, self.master_weights): + if master_weight is None: + continue -TEST_ROOT = Path(__file__).parent.resolve() -NUM_PROCS: int = min(2, torch.cuda.device_count()) -LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + # Copy updated master weights to local weights + weight.data.copy_(master_weight) + + # Step 4: All-gather updated weights across processes + dist.all_gather_into_tensor( + self.flatten_weight, self.local_weight_shard, group=self.dp_group + ) + + if self.manual_post_all_gather_processing: + quantized_weights = [ + weight for weight in self.weights if isinstance(weight, QuantizedTensor) + ] + post_all_gather_processing(quantized_weights) + + +def _test_mini_optimizer(dp_group): + """Make sure the implementation of MiniZero_1 and MiniFSDP is correct""" + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + weights = [ + torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"), + torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"), + torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"), + ] + + weights_1 = weights + weights_2 = [weight.clone() for weight in weights] + weights_3 = [weight.clone() for weight in weights] + + lr = 1.0 + optimizer_1 = MiniZero_1(weights_1, lr, dp_group) + optimizer_2 = MiniOptimizer(weights_2, lr, dp_group) + optimizer_3 = MiniFSDP(weights_3, lr, dp_group) + + for _ in range(100): + for w1, w2, w3 in zip(weights_1, weights_2, weights_3): + main_grads = [ + torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the grads of different ranks are different. + main_grad = main_grads[rank] + w1.main_grad = main_grad + w2.main_grad = main_grad + w3.main_grad = main_grad + + optimizer_1.step() + optimizer_2.step() + optimizer_3.step() + + for w1, w2 in zip(weights_1, weights_2): + torch.testing.assert_close(w1, w2, atol=0, rtol=0) + for w1, w3 in zip(weights_1, weights_3): + torch.testing.assert_close(w1, w3, atol=0, rtol=0) + + +def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gather_processing): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + + # Create model with FP8 weights + with te.quantized_model_init( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + # Allocate main_grads for each weight + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_fp8 = MiniZero_1( + [w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing + ) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for i in range(100): + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. + target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + +def _test_fsdp_cast_master_weights_to_fp8( + quantization, dp_group, manual_post_all_gather_processing +): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + # Configuration constants + NUM_STEPS = 100 + SEED = 12345 + + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = { + "params_dtype": torch.bfloat16, + "bias": False, + "fuse_wgrad_accumulation": True, + } + + # Create model with FP8 weights + with te.quantized_model_init( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + optimizer_fp8 = MiniFSDP( + [w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing + ) + optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + optimizer_fp8.zero_grad() + optimizer.zero_grad() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. + target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + +def run_parallel_tests() -> None: + """Run parallel tests""" + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + dp_group = dist.new_group(backend="nccl") + + quantizations = [] + if is_fp8_available(): + quantizations.extend(["fp8", "fp8_cs"]) + if is_fp8_block_scaling_available(): + quantizations.append("fp8_block") + + manual_post_all_gather_processings = [False, True] + + _test_mini_optimizer(dp_group) + + for quantization in quantizations: + for post_ag_processing in manual_post_all_gather_processings: + _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + + dist.destroy_process_group() + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="cast_master_weights_to_fp8 test needs at least 2 GPUs." +) +@pytest.mark.parametrize("world_size", [2]) +def test_cast_master_weights_to_fp8(world_size: int) -> None: + """Launch parallel job that runs parallel tests""" + python_exe = pathlib.Path(sys.executable).resolve() + current_file = pathlib.Path(__file__).resolve() + command = [ + python_exe, + "-m", + "torch.distributed.run", + f"--nproc_per_node={world_size}", + current_file, + "--parallel", + ] + result = subprocess.run( + command, + check=True, + ) -def _run_test(quantization): - test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py" - test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization] - result = subprocess.run(test_cmd, env=os.environ, check=False) - assert result.returncode == 0 +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + args = parser.parse_args() + if args.parallel: + run_parallel_tests() -@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"]) -def test_cast_master_weights_to_fp8(quantization): - if quantization in ("fp8", "fp8_cs") and not fp8_available: - pytest.skip(reason_for_no_fp8) - if quantization == "fp8_block" and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - _run_test(quantization) +if __name__ == "__main__": + main() diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 8354823b3..20aba6c2b 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -48,7 +48,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): def cast_master_weights_to_fp8( - model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None + model_weights, + master_weights, + start_offsets, + group, + fsdp_shard_model_weights=None, + manual_post_all_gather_processing=False, ): r"""Helper function to cast master weights to FP8 primary weights. @@ -69,6 +74,11 @@ def cast_master_weights_to_fp8( fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are not sharded. Otherwise, it means that the model weights are sharded and we get target model weights data storage using the FSDP shard model weights. + manual_post_all_gather_processing: bool, default = `False`. + If False, post processing will be automatically triggered during next forward. + If True, the timing of calling post_all_gather_processing is left to the user. + Note that users must call `post_all_gather_processing` if it's set to True, + otherwise the weights won't be updated correctly. """ @@ -129,21 +139,18 @@ def cast_master_weights_to_fp8( f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet" ) + extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: - _cast_master_weights_to_fp8_delayed_scaling( - delayed_scaling_params, group, use_fsdp_shard_model_weights - ) + _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, *extra_args) if len(current_scaling_params) > 0: - _cast_master_weights_to_fp8_current_scaling( - current_scaling_params, group, use_fsdp_shard_model_weights - ) + _cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args) if len(blockwise_scaling_params) > 0: - _cast_master_weights_to_fp8_blockwise_scaling( - blockwise_scaling_params, group, use_fsdp_shard_model_weights - ) + _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args) -def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False): +def _cast_master_weights_to_fp8_delayed_scaling( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. Parameters @@ -160,6 +167,13 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo amaxes, scales, scale_invs = [], [], [] for model_weight, master_weight, start_offset, shard_model_weight_raw in params: + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight._reset_caches() + quantizer = model_weight._get_quantizer() amaxes.append(quantizer.amax.view(1)) @@ -219,7 +233,9 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo ) -def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False): +def _cast_master_weights_to_fp8_current_scaling( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): r"""Helper function to cast master weights to FP8 primary weights for current scaling. Parameters @@ -297,6 +313,13 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( params, scales ): + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight._reset_caches() + # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. if master_weight is None: @@ -322,7 +345,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo def _cast_master_weights_to_fp8_blockwise_scaling( - params, group, use_fsdp_shard_model_weights=False + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False ): r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling. @@ -421,6 +444,13 @@ def _cast_master_weights_to_fp8_blockwise_scaling( for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( params, scales ): + if not manual_post_all_gather_processing: + # Clear columnwise data for all model weights. + # We cannot create columnwise data here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated at this moment. + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. if master_weight is None: From 4ff3eed10acfa0ef88de3f80e6ab6349f9604523 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 6 Nov 2025 16:46:40 -0800 Subject: [PATCH 106/141] [JAX] Add test to check jaxpr that amax is reused for nvfp4 recipe (#2348) * Add test to check jaxpr that amax is reused for nvfp4 recipe Signed-off-by: Jeremy Berchtold * Move test to test_helper.py and rename file Signed-off-by: Jeremy Berchtold * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jeremy Berchtold Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_custom_call_compute.py | 1 - ...lper.py => test_recipe_characteristics.py} | 67 ++++++++++++++++++- 2 files changed, 66 insertions(+), 2 deletions(-) rename tests/jax/{test_helper.py => test_recipe_characteristics.py} (78%) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index cecdb3121..3d4f179ab 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -45,7 +45,6 @@ from transformer_engine.jax.activation import activation from transformer_engine.jax.dense import dense, grouped_dense from transformer_engine.jax.layernorm_dense import layernorm_dense -from transformer_engine.common import recipe GEMM_CASES = [ (256, 256, 512), diff --git a/tests/jax/test_helper.py b/tests/jax/test_recipe_characteristics.py similarity index 78% rename from tests/jax/test_helper.py rename to tests/jax/test_recipe_characteristics.py index fc88b7ef7..33fde7e23 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_recipe_characteristics.py @@ -11,7 +11,7 @@ import numpy as np from flax import linen as nn -from utils import assert_allclose +from utils import assert_allclose, pytest_parametrize_wrapper from transformer_engine.common.recipe import ( DelayedScaling, MXFP8BlockScaling, @@ -22,6 +22,7 @@ from transformer_engine.jax import autocast from transformer_engine.jax.quantize import ( get_quantize_config, + get_supported_quantization_recipes, is_scaling_mode_supported, ScalingMode, update_collections, @@ -32,11 +33,15 @@ from transformer_engine.jax.quantize.helper import _format2dtypes from transformer_engine.jax.sharding import MeshResource, global_mesh_resource from transformer_engine.jax.flax.module import TransformerEngineBase +from transformer_engine.jax import flax as te_flax +import transformer_engine.jax as te is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) +SUPPORTED_RECIPES = get_supported_quantization_recipes() + def quantizer_check_vjp(outer_quantizer_set, assertion_func, x): """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries.""" @@ -253,3 +258,63 @@ def test_autocast_nvfp4_block_scaling(self): self._compare_nvfp4_scaling_quantizers(bs) self._check_default_state() + + +class TestJaxprAndHlo: + """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations.""" + + @pytest_parametrize_wrapper( + "quantization_recipe", + [ + quantization_recipe + for quantization_recipe in SUPPORTED_RECIPES + if isinstance(quantization_recipe, NVFP4BlockScaling) + ], + ) + def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe): + """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton.""" + + with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()): + model = te_flax.LayerNormMLP( + layernorm_type="rmsnorm", + return_layernorm_output=False, + intermediate_dropout_rate=0.0, + dtype=jnp.bfloat16, + ) + + var_collect = model.init( + jax.random.PRNGKey(0), + jnp.ones((128, 128), dtype=jnp.bfloat16), + ) + + def loss_fn(x, rngs): + return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0]) + + x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16) + rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)} + jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs) + + rht_amax_eqns = [ + eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper" + ] + + assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}" + + def assert_param(index, tensor_name, expected_value: bool): + if expected_value: + assert rht_amax_eqns[index].params["produce_regular_amax"] == True, ( + f"Expected produce_regular_amax for {tensor_name} to be True, indicating no" + " reuse of amax as this tensor does not have a previous operation to fuse" + " with" + ) + else: + assert rht_amax_eqns[index].params["produce_regular_amax"] == False, ( + f"Expected produce_regular_amax for {tensor_name} to be False, indicating" + " reuse of amax" + ) + + assert_param(0, "fwd ln+q", False) + assert_param(1, "fwd act+q", False) + # No previous op before incoming dgrad in the backward so amax is not reused + assert_param(2, "bwd dgrad", True) + assert_param(3, "bwd dact+q", False) From f62cad90b1ed100b39bc01b2c6556b4033811714 Mon Sep 17 00:00:00 2001 From: Michael Goldfarb Date: Thu, 6 Nov 2025 20:04:40 -0600 Subject: [PATCH 107/141] Fix sharding of segment position to match id in ring attention. (#2349) --- .../jax/cpp_extensions/attention.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index c0cb6cda1..6a21480d8 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1784,6 +1784,9 @@ def partition(config, mesh, arg_infos, result_infos): ) arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[4] = seed_sharding + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -1991,7 +1994,13 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) helper = _FusedAttnCPWithP2PHelper(mesh, config) @@ -2265,6 +2274,9 @@ def partition(config, mesh, arg_infos, result_infos): ) arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[4] = seed_sharding + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -2403,7 +2415,11 @@ def partition(config, mesh, arg_infos, result_infos): if not is_context_parallel: return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) - arg_shardings = tuple(arg.sharding for arg in arg_infos) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) # dq, dk, dv, dbias sharding = q, k, v, bias sharding out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) From 26aad6b0faae88f4865fe6ace357b0f56485267e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 7 Nov 2025 12:00:31 -0500 Subject: [PATCH 108/141] Disable cuDNN attention for known IMA and NaNs (#2344) * Fix cuDNN backend selection for more case. Add CG as a option as well Signed-off-by: Kirthi Shankar Sivamani * fix logic Signed-off-by: Kirthi Shankar Sivamani * Fix cuDNN checks Signed-off-by: Kirthi Shankar Sivamani * Add more checks Signed-off-by: Kirthi Shankar Sivamani * Fix cuddn version Signed-off-by: Kirthi Shankar Sivamani * Fix error message Signed-off-by: Kirthi Shankar Sivamani * Add check for window size Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- .../common/fused_attn/fused_attn.cpp | 109 +++++++++++------- .../include/transformer_engine/fused_attn.h | 56 +++++---- .../jax/csrc/extensions/attention.cpp | 54 ++++----- .../dot_product_attention/backends.py | 5 + .../dot_product_attention/context_parallel.py | 8 ++ .../dot_product_attention.py | 1 + .../attention/dot_product_attention/utils.py | 5 + .../pytorch/cpp_extensions/fused_attn.py | 8 ++ transformer_engine/pytorch/csrc/extensions.h | 6 +- .../pytorch/csrc/extensions/attention.cpp | 42 +++---- 10 files changed, 178 insertions(+), 116 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index f6ee37d4c..9c6e9b33d 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -138,7 +138,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -166,7 +166,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || @@ -407,6 +407,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( " Please upgrade your cuDNN version if possible." << std::endl; } + if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-causal) and " + "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if ((cudnn_runtime_version <= 91500) && is_training && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-padding)," + " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" + " backward fused attention with graph capture requires cuDNN 9.15.1+. " + "Please upgrade your cuDNN version if possible." + << std::endl; + } } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -419,11 +441,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, bool return_max_logit, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); using namespace transformer_engine; @@ -460,7 +482,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, - h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit); + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -496,16 +519,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } } // NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -544,7 +565,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false); + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -602,10 +623,10 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream) { + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -681,7 +702,7 @@ void nvte_fused_attn_fwd_kvpacked( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, - return_max_logit); + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -728,7 +749,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream) { + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -776,9 +798,10 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false); + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -833,16 +856,19 @@ void nvte_fused_attn_bwd_kvpacked( } } // NVTE fused attention FWD with separate Q, K and V -void nvte_fused_attn_fwd( - const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, - const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, + cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -913,7 +939,7 @@ void nvte_fused_attn_fwd( NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, - return_max_logit); + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -963,7 +989,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, - NVTETensor workspace, cudaStream_t stream) { + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -1008,7 +1034,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, - h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false); + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 518fad20d..40e6a0b4b 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -207,13 +207,14 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit); + int64_t window_size_right, bool return_max_logit, bool cuda_graph); /*! \brief Compute dot product attention with packed QKV input. * @@ -257,6 +258,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -273,11 +275,11 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen, bool is_training, bool return_max_logit, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream); + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * @@ -324,19 +326,18 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, NVTETensor dSoftmaxOffset, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - size_t max_seqlen, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream); +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. * @@ -387,6 +388,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. @@ -405,10 +407,10 @@ void nvte_fused_attn_fwd_kvpacked( const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, - size_t max_seqlen_kv, bool is_training, bool return_max_logit, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - NVTETensor workspace, cudaStream_t stream); + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. * @@ -461,6 +463,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -472,7 +475,8 @@ void nvte_fused_attn_bwd_kvpacked( const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, cudaStream_t stream); + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, + cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. * @@ -527,6 +531,7 @@ void nvte_fused_attn_bwd_kvpacked( * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. @@ -545,9 +550,9 @@ void nvte_fused_attn_fwd( const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -605,6 +610,7 @@ void nvte_fused_attn_fwd( * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ @@ -619,7 +625,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool deterministic, - NVTETensor workspace, cudaStream_t stream); + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index ffc0706fe..a99f4fae9 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -23,7 +23,7 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false); + false, false); return backend; } @@ -180,7 +180,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { @@ -189,7 +189,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { @@ -199,7 +199,7 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, qkv_layout, + kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } else { @@ -279,7 +279,7 @@ static void FusedAttnForwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false); + false, false); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -298,7 +298,7 @@ static void FusedAttnForwardImpl( qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -311,7 +311,7 @@ static void FusedAttnForwardImpl( s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, dropout_probability, + q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { @@ -326,9 +326,9 @@ static void FusedAttnForwardImpl( dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -480,7 +480,7 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, query_workspace_tensor.data(), nullptr); + deterministic, false, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { nvte_fused_attn_bwd_kvpacked( q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), @@ -491,19 +491,19 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, query_workspace_tensor.data(), nullptr); + window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, query_workspace_tensor.data(), nullptr); + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); } else { NVTE_ERROR("Unsupported qkv_layout."); } @@ -546,7 +546,7 @@ static void FusedAttnBackwardImpl( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, - false); + false, false); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); @@ -568,7 +568,7 @@ static void FusedAttnBackwardImpl( q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, - workspace_tensor.data(), stream); + false, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto kv_shape = @@ -590,7 +590,7 @@ static void FusedAttnBackwardImpl( dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, deterministic, + mask_type, softmax_type, window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; @@ -617,7 +617,7 @@ static void FusedAttnBackwardImpl( q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, - workspace_tensor.data(), stream); + false, workspace_tensor.data(), stream); } else { NVTE_ERROR("Unsupported qkv_layout."); } diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index d4903be90..147a85fc2 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -66,6 +66,7 @@ ) from transformer_engine.pytorch import export from transformer_engine.pytorch.export import is_in_onnx_export_mode +from transformer_engine.pytorch.graph import is_graph_capturing # Global vars for flash attn v2 and v3 imports flash_attn_cuda_bwd = None @@ -1199,6 +1200,7 @@ def forward( window_size, rng_gen, softmax_offset, + cuda_graph=is_graph_capturing(), ) # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1276,6 +1278,7 @@ def forward( rng_gen, softmax_offset, return_max_logit, + is_graph_capturing(), ) out = out_ out_ret = out_ @@ -1515,6 +1518,7 @@ def backward(ctx, d_out, *_args): ctx.softmax_type, ctx.window_size, ctx.deterministic, + is_graph_capturing(), ) # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 @@ -1579,6 +1583,7 @@ def backward(ctx, d_out, *_args): ctx.softmax_type, ctx.window_size, ctx.deterministic, + is_graph_capturing(), ) d_bias = None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f312cac79..00d609ab9 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -23,6 +23,7 @@ from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser +from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.constants import ( dist_group_type, TE_DType, @@ -33,6 +34,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) + from transformer_engine.pytorch.quantized_tensor import ( prepare_for_saving, restore_from_saved, @@ -715,6 +717,7 @@ def cp_p2p_fwd_fused_attn( cu_seqlens_kv_padded=cu_seqlens_kv_padded_, **fp8_meta_kwargs, return_max_logit=return_max_logit, + cuda_graph=is_graph_capturing(), ) if fp8: @@ -977,6 +980,7 @@ def cp_p2p_bwd_fused_attn( attn_mask_type=attn_mask_type_, attn_bias_type=attn_bias_type, deterministic=deterministic, + cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, ) @@ -2772,6 +2776,7 @@ def forward( cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], window_size=window_size_per_step[i], return_max_logit=return_max_logit, + cuda_graph=is_graph_capturing(), ) if return_max_logit: max_logit_per_step[i] = max_logit_[0] @@ -2986,6 +2991,7 @@ def backward(ctx, dout, *_args): attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, + cuda_graph=is_graph_capturing(), ) else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ @@ -3282,6 +3288,7 @@ def forward( softmax_type=softmax_type, softmax_offset=softmax_offset, return_max_logit=return_max_logit, + cuda_graph=is_graph_capturing(), ) if isinstance(out_, Float8Tensor): out_fp8 = out_ @@ -3559,6 +3566,7 @@ def backward(ctx, dout, *_args): attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, deterministic=ctx.deterministic, + cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, softmax_type=ctx.softmax_type, ) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 0d1c0b0c0..4278820e7 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1314,6 +1314,7 @@ def forward( inference_params=inference_params, softmax_type=self.softmax_type, return_max_logit=self.return_max_logit, + cuda_graph=is_graph_capturing(), ) global _attention_backends if is_in_onnx_export_mode(): diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7d4a4f86d..a08ba1419 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -231,6 +231,8 @@ class AttentionParams: The type of softmax operation. See DotProductAttention for details. return_max_logit: bool, default = `False` Whether to output max_logit. + cuda_graph: bool, default = `False` + Whether support for cuda graph capture is needed or not. """ qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -260,6 +262,7 @@ class AttentionParams: inference_params: Optional[InferenceParams] = None softmax_type: str = "vanilla" return_max_logit: bool = False + cuda_graph: bool = False def __eq__(self, other): """ @@ -334,6 +337,7 @@ def get_attention_backend( inference_params = attention_params.inference_params softmax_type = attention_params.softmax_type return_max_logit = attention_params.return_max_logit + cuda_graph = attention_params.cuda_graph # Run config logger = logging.getLogger("DotProductAttention") @@ -979,6 +983,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt window_size[0], window_size[1], return_max_logit, + cuda_graph, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index eb43c75f6..e55ea2a54 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -140,6 +140,7 @@ def fused_attn_fwd( rng_gen: torch.Generator = None, softmax_offset: torch.Tensor = None, return_max_logit: bool = False, + cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -219,6 +220,8 @@ def fused_attn_fwd( See softmax_type in DotProductAttention for details. return_max_logit: bool, default = False whether to return the maximum attention score + cuda_graph: bool, default = False + whether or not cuda graph capture is enabled. Returns ---------- @@ -320,6 +323,7 @@ def fused_attn_fwd( rng_gen, rng_elts_per_thread, return_max_logit, + cuda_graph, ) if return_max_logit: @@ -367,6 +371,7 @@ def fused_attn_bwd( softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), deterministic: bool = False, + cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. @@ -439,6 +444,8 @@ def fused_attn_bwd( window and causal mask specifically. deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. + cuda_graph: bool, default = False + whether or not cuda graph capture is enabled. Returns ---------- @@ -509,6 +516,7 @@ def fused_attn_bwd( s_quantizer, dp_quantizer, dqkv_quantizer, + cuda_graph, ) return output_tensors diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 79fb79842..43eab9654 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -76,7 +76,7 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit); + int64_t window_size_right, bool return_max_logit, bool cuda_graph); std::pair quantizer_helper(py::handle quantizer, const std::vector &shape, DType dtype, @@ -94,7 +94,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit); + size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, @@ -106,7 +106,7 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer); + py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index f66c8aa61..d51aef406 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -45,12 +45,12 @@ NVTE_Fused_Attn_Backend get_fused_attn_backend( NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, - int64_t window_size_right, bool return_max_logit) { + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, - return_max_logit); + return_max_logit, cuda_graph); return fused_attention_backend; } @@ -107,7 +107,7 @@ std::vector fused_attn_fwd( const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, const std::optional SoftmaxOffset, const std::optional rng_gen, - size_t rng_elts_per_thread, bool return_max_logit) { + size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { auto none = py::none(); // create QKV tensor wrappers @@ -229,7 +229,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -289,7 +289,7 @@ std::vector fused_attn_fwd( te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); }); @@ -312,7 +312,7 @@ std::vector fused_attn_bwd( const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer) { + py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { auto none = py::none(); // create QKV, O, dO tensor wrappers @@ -527,13 +527,14 @@ std::vector fused_attn_bwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -543,13 +544,14 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_dSoftmaxOffset.data(), te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), - te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size[0], - window_size[1], deterministic, workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers From 5978f1d7544fe0e4b06a036600c80c6a8c8203fe Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Fri, 7 Nov 2025 11:17:16 -0800 Subject: [PATCH 109/141] [JAX] Default to fused attention in JAX DPA (#2363) * Default to fused attention in JAX DPA Signed-off-by: Kshitij Lakhani * Consolidate documentation for DPA in JAX Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> * Correctly update the documentation for defaults in JAX DPA Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> --------- Signed-off-by: Kshitij Lakhani Signed-off-by: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- transformer_engine/jax/flax/transformer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 42c945124..86af6cf49 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -407,10 +407,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment variable: - * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default). - * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention - kernel is not available on the system, a warning will be issued, and the module will - automatically fall back to the unfused backend. + * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention. + * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused + attention kernel is not available on the system, a warning will be issued, and the module + will automatically fall back to the unfused backend. .. note:: The DotProductAttention default setting enables non-deterministic kernels for reduced @@ -602,7 +602,8 @@ def __call__( else: assert bias is not None - enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) + # Use fused attn (if kernel check below passes) by default + enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1")) sequence_dim = 0 if self.transpose_batch_sequence else 1 seqlen_q = query.shape[sequence_dim] From d20311bd5ae6bae3cd83368736d482c25f52a1a3 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 7 Nov 2025 14:31:51 -0500 Subject: [PATCH 110/141] Update cudnn frontend to v1.16.0 (#2362) Signed-off-by: Kirthi Shankar Sivamani --- 3rdparty/cudnn-frontend | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 0b1577c8c..be6c079be 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 0b1577c8c83401237d601d0d0db5210506705396 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 From 3454f84da64cc117fde3c7c5286041d698576dea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 7 Nov 2025 23:08:22 +0100 Subject: [PATCH 111/141] [common] Remove kvpacked and qkvpacked attention functions for every kernel type. (#2287) * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * depracted compile time warning + \warning -> \deprecated Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 333 +++++++++-- .../fused_attn_f16_arbitrary_seqlen.cu | 530 +----------------- .../fused_attn_f16_arbitrary_seqlen.h | 47 -- .../fused_attn_f16_max512_seqlen.cu | 264 --------- .../fused_attn/fused_attn_f16_max512_seqlen.h | 37 -- .../common/fused_attn/fused_attn_fp8.cu | 418 -------------- .../common/fused_attn/fused_attn_fp8.h | 41 -- .../include/transformer_engine/fused_attn.h | 20 + 8 files changed, 302 insertions(+), 1388 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9c6e9b33d..ac6fefdc6 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,6 +15,74 @@ #include "fused_attn_fp8.h" #include "utils.h" +namespace { +// Helper function to create a tensor view with modified shape and optional pointer offset +transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source, + const std::vector &shape, + size_t offset_bytes = 0) { + transformer_engine::Tensor view = *source; + if (offset_bytes > 0) { + view.data.dptr = static_cast(static_cast(source->data.dptr) + offset_bytes); + } + view.data.shape = shape; + view.nvte_tensor = 0; // Mark as unmanaged/local tensor view + return view; +} + +// Helper function to calculate stride for packed QKV tensor unpacking +size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for QKV packed tensor +std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tensor *qkv_tensor, + size_t h, size_t d) { + std::vector unpacked_shape; + if (qkv_tensor->data.shape.size() == 4) { + // T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1 + unpacked_shape = {qkv_tensor->data.shape[0], h, d}; + } else { + // BS3HD/SB3HD or BSH3D/SBH3D (5D) -> BSHD/SBHD (4D): remove dimension "3" at position 2 + unpacked_shape = {qkv_tensor->data.shape[0], qkv_tensor->data.shape[1], h, d}; + } + return unpacked_shape; +} + +// Helper function to calculate stride for packed KV tensor unpacking +size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h_kv, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for KV packed tensor +std::vector calculate_kv_unpacked_shape(const transformer_engine::Tensor *kv_tensor, + NVTE_QKV_Layout_Group layout_group, + NVTE_QKV_Format kv_format, size_t t_kv, size_t h_kv, + size_t d) { + std::vector unpacked_kv_shape; + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + unpacked_kv_shape = {t_kv, h_kv, d}; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD || + layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + unpacked_kv_shape = {kv_tensor->data.shape[0], kv_tensor->data.shape[1], h_kv, d}; + } + return unpacked_kv_shape; +} +} // namespace + // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { switch (qkv_layout) { @@ -436,6 +504,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -487,30 +557,62 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, - output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, return_max_logit, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, input_QKV, - input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, - input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, - stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -519,6 +621,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } } // NVTE fused attention BWD with packed QKV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. void nvte_fused_attn_bwd_qkvpacked( const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, @@ -570,9 +674,25 @@ void nvte_fused_attn_bwd_qkvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_qkvpacked( - b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, - input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, + &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, + input_cu_seqlens, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -588,12 +708,27 @@ void nvte_fused_attn_bwd_qkvpacked( if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - softmax_type, window_size_left, window_size_right, deterministic, input_QKV, input_O, - input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQKV, output_dBias, - output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, + &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, + &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -605,10 +740,26 @@ void nvte_fused_attn_bwd_qkvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, - input_S, input_output_dP, output_dQKV, input_cu_seqlens, - input_rng_state, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, + input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -617,6 +768,8 @@ void nvte_fused_attn_bwd_qkvpacked( } } // NVTE fused attention FWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -706,21 +859,40 @@ void nvte_fused_attn_fwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, - window_size_left, window_size_right, input_Q, input_KV, input_Bias, input_SoftmaxOffset, - output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); #else @@ -729,10 +901,20 @@ void nvte_fused_attn_fwd_kvpacked( #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -741,6 +923,8 @@ void nvte_fused_attn_fwd_kvpacked( } } // NVTE fused attention BWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, @@ -806,10 +990,23 @@ void nvte_fused_attn_bwd_kvpacked( if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif @@ -825,13 +1022,29 @@ void nvte_fused_attn_bwd_kvpacked( if (softmax_type != NVTE_VANILLA_SOFTMAX) { input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + // Create tensor views for dK, dV + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, - input_Q, input_KV, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, output_dQ, - output_dKV, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, - input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, - handle); + input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -843,11 +1056,25 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, + stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 950ced61b..14468b543 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -1037,532 +1037,6 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_logit, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS1 = nullptr; - void *devPtrS2 = nullptr; - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Max->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Max->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_Max->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_Max->data.dtype = DType::kFloat32; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Sum_Exp->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Sum_Exp->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_Sum_Exp->data.dtype = DType::kFloat32; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - } - - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_Max->data.dptr; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS2 = output_Sum_Exp->data.dptr; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_S->data.dptr; - } - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrSoftmaxOffset, - devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, devPtrdSoftmaxOffset, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, devPtrSeqOffsets, devPtrSeqOffsets, - get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - void *devPtrSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS1 = nullptr; - void *devPtrS2 = nullptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - size_t i = 0; - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Max->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_Max->data.dtype = DType::kFloat32; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_Sum_Exp->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_Sum_Exp->data.dtype = DType::kFloat32; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - } - - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } - - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = nullptr; - output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; - output_softmax_offset->data.dtype = DType::kFloat32; - } - - Aux_CTX_Tensors->size = i; - } else if (Aux_CTX_Tensors->size >= 2) { - if (return_max_logit) { - Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_Max->data.dptr; - Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS2 = output_Sum_Exp->data.dptr; - } else { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrS1 = output_S->data.dptr; - } - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_rng_state->data.dptr = rng_state->data.dptr; - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_bias->data.dptr = devPtrBias; - } - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_softmax_offset->data.dptr = devPtrSoftmaxOffset; - } - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, - devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, - devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, - devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - void *devPtrSoftmaxOffset = nullptr; - void *devPtrdSoftmaxOffset = nullptr; - if (softmax_type != NVTE_VANILLA_SOFTMAX) { - devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; - devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; - } - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, - devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, - devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, @@ -1604,8 +1078,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; + void *devPtrPageTableK = page_table_k ? page_table_k->data.dptr : nullptr; + void *devPtrPageTableV = page_table_v ? page_table_v->data.dptr : nullptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index a3181c629..872b798bb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -18,53 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, bool return_max_logit, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_QKV, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *cu_seqlens_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_QKV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, - Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, bool return_max_logit, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dKV, - Tensor *output_dBias, Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 89528fa3c..1028df645 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_max_512_fwd_qkvpacked( - size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - const auto stride = 2 * num_head * head_dim; - - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = static_cast(input_Bias->data.dptr); - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; - output_S->data.dtype = input_QKV->data.dtype; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrCuSeqlen = cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - const DType QKV_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, - "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); - - // Q shape is [b, s, h, d] - void *devPtrQ = input_Q->data.dptr; - - // KV shape is [b, s, 2, h, d] - const auto stride = 2 * num_head * head_dim; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, } } -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - - auto stride = 2 * num_head * head_dim; - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQKV shape is [b, s, 3, h, d] - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - - const auto qkv_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK, - devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS, - devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(qkv_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // Q shape is [b, s, h, d] - // KV shape is [b, s, 2, h, d] - auto stride = 2 * num_head * head_dim; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQ shape is [b, s, h, d] - // dKV shape is [b, s, 2, h, d] - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dKV->data.dptr; - void *devPtrdV = static_cast(static_cast(devPtrdK) + stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 171fe846c..57b7afcf4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -18,25 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8901) -void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_size, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 7b85be972..5d806290a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -2407,424 +2407,6 @@ void fused_attn_fp8_bwd_impl_v1( } // namespace fused_attn #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType O_type = output_O->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = static_cast(devPtrQKV); - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, - devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType dO_type = input_dO->data.dtype; - const DType dQKV_type = output_dQKV->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = devPtrQKV; - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - const DType O_type = input_O->data.dtype; - void* devPtrDescaleO = nullptr; - if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { - devPtrDescaleO = input_O->scale_inv.dptr; - } - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQKV = output_dQKV->data.dptr; - void* devPtrdQ = devPtrdQKV; - void* devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void* devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - void* devPtrAmaxdQ = output_dQKV->amax.dptr; - void* devPtrAmaxdK = output_dQKV->amax.dptr; - void* devPtrAmaxdV = output_dQKV->amax.dptr; - void* devPtrScaledQ = output_dQKV->scale.dptr; - void* devPtrScaledK = output_dQKV->scale.dptr; - void* devPtrScaledV = output_dQKV->scale.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, - devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, - devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, - devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, - Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType O_type = output_O->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType dO_type = input_dO->data.dtype; - const DType dQKV_type = output_dQ->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - const DType O_type = input_O->data.dtype; - void* devPtrDescaleO = nullptr; - if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { - devPtrDescaleO = input_O->scale_inv.dptr; - } - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQ = output_dQ->data.dptr; - void* devPtrdKV = output_dKV->data.dptr; - void* devPtrdK = devPtrdKV; - void* devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dKV->amax.dptr; - void* devPtrAmaxdV = output_dKV->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dKV->scale.dptr; - void* devPtrScaledV = output_dKV->scale.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3daf45d16..c2efa2582 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -13,47 +13,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, - const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 40e6a0b4b..298dc6390 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -217,6 +217,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph); /*! \brief Compute dot product attention with packed QKV input. + * + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -270,6 +272,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, @@ -282,6 +287,8 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. + * + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -330,6 +337,9 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_bwd_qkvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_qkvpacked( const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, @@ -340,6 +350,8 @@ void nvte_fused_attn_bwd_qkvpacked( NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. + * + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -401,6 +413,9 @@ void nvte_fused_attn_bwd_qkvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, @@ -413,6 +428,8 @@ void nvte_fused_attn_fwd_kvpacked( int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed KV input. + * + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * * Support Matrix: \verbatim @@ -467,6 +484,9 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_bwd_kvpacked() is deprecated. Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, From 5ea83432a400481b73e42de18bee7c206cb18fac Mon Sep 17 00:00:00 2001 From: Teddy Do Date: Mon, 10 Nov 2025 10:42:57 -0800 Subject: [PATCH 112/141] Move Triton to common (#2359) * move triton to common and change paths Signed-off-by: tdophung * Formatting Signed-off-by: tdophung --------- Signed-off-by: tdophung --- transformer_engine/common/triton/__init__.py | 5 + .../common/triton/cross_entropy.py | 252 ++++++++ transformer_engine/common/triton/pad.py | 59 ++ .../common/triton/permutation.py | 605 +++++++++++++++++ transformer_engine/pytorch/distributed.py | 3 +- transformer_engine/pytorch/triton/__init__.py | 2 +- .../pytorch/triton/cross_entropy.py | 252 +------- transformer_engine/pytorch/triton/pad.py | 55 +- .../pytorch/triton/permutation.py | 609 +----------------- 9 files changed, 943 insertions(+), 899 deletions(-) create mode 100644 transformer_engine/common/triton/__init__.py create mode 100644 transformer_engine/common/triton/cross_entropy.py create mode 100644 transformer_engine/common/triton/pad.py create mode 100644 transformer_engine/common/triton/permutation.py diff --git a/transformer_engine/common/triton/__init__.py b/transformer_engine/common/triton/__init__.py new file mode 100644 index 000000000..76c9b98d0 --- /dev/null +++ b/transformer_engine/common/triton/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Kernels written with OpenAI Triton.""" diff --git a/transformer_engine/common/triton/cross_entropy.py b/transformer_engine/common/triton/cross_entropy.py new file mode 100644 index 000000000..fc49ac20b --- /dev/null +++ b/transformer_engine/common/triton/cross_entropy.py @@ -0,0 +1,252 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient Cross Entropy kernels written with OpenAI Triton.""" + +import triton +import triton.language as tl + + +@triton.jit +def online_softmax_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes the m/d components on this TP rank for the online softmax. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride (int): The stride of the m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) + else: + X_y = float("-inf") + else: + X_y = float("-inf") + + m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. use the notation from the paper + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( + tl.float32 + ) + block_max = tl.max(X_block) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + tl.store(m_d_X_y_ptr, m) + tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) + tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) + + +@triton.jit +def cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + loss_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + world_size, + ignore_idx, + n_cols, + n_non_ignore, + reduce_loss: tl.constexpr, + label_smoothing: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + loss_stride (int): The stride of the loss tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride: The stride of m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + world_size (int): The size of world involved in this distributed loss calculation. + ignore_idx (int): Tokens to be ignored for loss and gradient calculation. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + if y == ignore_idx: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride + + # Need to reduce the m/d/X_y values from other TP ranks + m = tl.load(m_d_X_y_ptr) + d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) + ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) + + for i in range(1, world_size): + offset = i * 3 * n_non_ignore * m_d_X_y_stride + access_ptr = m_d_X_y_ptr + offset + m_new = tl.load(access_ptr) + d_new = tl.load(access_ptr + m_d_X_y_stride) + X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) + + d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) + m = tl.maximum(m, m_new) + ori_X_y = tl.maximum(ori_X_y, X_y_new) + + # Label smoothing is a general case of normal cross entropy + scaled_x_sum = 0.0 + eps = label_smoothing / (n_cols * world_size) + + # 4. [Online softmax] second pass: calculate the gradients + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # N is the number of non ignored elements in the batch + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + grad_dtype = X_block.dtype + X_block = X_block.to(tl.float32) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + # Scale gradients based on reduction mode + # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore + # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here + if reduce_loss: + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + else: + X_block = tl.exp(X_block - m) / d - eps + tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written + tl.debug_barrier() + + # 5. Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + loss = -(ori_X_y - m - tl.log(d)) + + # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + loss = loss * (1 - label_smoothing) + smooth_loss + + # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx) + # Apply the same conditional scaling logic for the target token + if reduce_loss: + X_y += -(1 - label_smoothing) / (n_non_ignore) + else: + X_y += -(1 - label_smoothing) + tl.store(X_ptr + y - vocab_start_idx, X_y) + + tl.store(loss_ptr, loss) + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + grad_output_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output_ptr += program_id * grad_output_stride + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/transformer_engine/common/triton/pad.py b/transformer_engine/common/triton/pad.py new file mode 100644 index 000000000..8f15e7dcb --- /dev/null +++ b/transformer_engine/common/triton/pad.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient NVFP4 padding kernels written with OpenAI Triton . + +TODO(ksivamani): Documentation + +""" + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1), + ], + key=["out_dim0", "out_dim1"], +) +@triton.jit +def zero_pad_kernel( + inp_ptr, + out_ptr, + in_dim0: tl.constexpr, + in_dim1: tl.constexpr, + out_dim0: tl.constexpr, + out_dim1: tl.constexpr, + in_s0, + in_s1, + out_s0, + out_s1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + # tile over OUTPUT coordinates + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols + om = offs_m[:, None] + on = offs_n[None, :] + + # edge masking for output + out_mask = (om < out_dim0) & (on < out_dim1) + + # valid input region is simply top-left (no offsets) + in_mask = (om < in_dim0) & (on < in_dim1) + + # load valid input, else zero (masked load touches memory only where True) + x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0) + + # store to output (only within bounds of the output tile) + tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask) diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py new file mode 100644 index 000000000..3a3a32014 --- /dev/null +++ b/transformer_engine/common/triton/permutation.py @@ -0,0 +1,605 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient Permutation kernels written with OpenAI Triton.""" + +import triton +import triton.language as tl + +from triton.language import core +from triton.language.standard import _log2 +from packaging import version + + +# The following three argsort related kernels are adapted from +# the issue https://github.com/triton-lang/triton/issues/3698 + +get_int_dtype = core.get_int_dtype +if version.parse(triton.__version__) >= version.parse("3.5.0"): + get_int_dtype = triton.constexpr_function(get_int_dtype) + + +@triton.jit +def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)] + y = tl.reshape(x, shape) + z = tl.reshape(indices, shape) + + mask = tl.arange(0, 2)[None, :, None] + + l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to( + x.dtype + ) + r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to( + x.dtype + ) + + l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape) + r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) + + idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + il_value = l_value.to(idtype, bitcast=True) + ir_value = r_value.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix)) + ret = ix ^ flag1 + flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix)) + ind = indices ^ flag2 + + return ret.to(x.dtype, bitcast=True), ind + + +@triton.jit +def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + """ + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + """ + if order == 2: + shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = tl.full(x.shape, value=order, dtype=tl.int32) + for i in tl.static_range(stage): + x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims) + return x, indices + + +@triton.jit +def _argsort(x, indices, n_dims: tl.constexpr): + for i in tl.static_range(1, n_dims + 1): + x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims) + return x, indices + + +@triton.jit +def _row_id_map_pass_1_kernel( + # pointers + routing_map_ptr, + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # strides + stride_routing_map_token, + stride_routing_map_expert, + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + expert_token_mask = tl.load( + routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, + mask=(offset < num_tokens), + other=0, + ).to(tl.int32) + row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask + tl.store( + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + row_id_within_token_block, + mask=offset < num_tokens, + ) + n_tokens_per_block = tl.sum(expert_token_mask) + tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block) + + +@triton.jit +def _row_id_map_pass_2_kernel( + # pointers + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + WORKSPACE_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + row_id_within_token_block = tl.load( + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + mask=(offset < num_tokens), + other=0, + ) + + workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) + n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx) + row_id = tl.where( + row_id_within_token_block == 0, + -1, + row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, + ) + tl.store( + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + row_id, + mask=(offset < num_tokens), + ) + + +@triton.jit +def _row_id_map_pass_3_kernel( + # pointers + row_id_map_ptr, + # sizes + num_experts: tl.constexpr, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + LOAD_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + n_dims: tl.constexpr = _log2(LOAD_SIZE) + off = tl.arange(0, LOAD_SIZE) + row_id_map = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off, + mask=off < num_experts, + other=-1, + ) + n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0)) + indices = off + sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims) + tl.store( + row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert, + sorted_map, + mask=off < n_routed, + ) + tl.store( + row_id_map_ptr + + pid * stride_row_id_map_token + + (num_experts + off) * stride_row_id_map_expert, + indices, + mask=off < n_routed, + ) + tl.store( + row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert, + n_routed, + ) + + +@triton.jit +def _permute_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + probs_ptr, + scale_ptr, + permuted_probs_ptr, + permuted_scale_ptr, + # sizes + num_experts: tl.constexpr, + hidden_size: tl.constexpr, + scale_hidden_dim, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_probs_expert, + stride_scale_token, + stride_scale_hidden, + stride_permuted_probs_token, + stride_permuted_scale_token, + stride_permuted_scale_hidden, + # metas + PERMUTE_PROBS: tl.constexpr, + PERMUTE_SCALE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = cur_off < hidden_size + src_row = pid_t.to(tl.int64) + input_off = src_row * stride_input_token + cur_off * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + if PERMUTE_SCALE: + mask_scale = cur_off < scale_hidden_dim + scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden + scale = tl.load(scale_ptr + scale_off, mask=mask_scale) + n_routed = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + dst_row = tl.load( + row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert + ).to(tl.int64) + output_off = dst_row * stride_output_token + cur_off * stride_output_hidden + if PERMUTE_SCALE: + permuted_scale_off = ( + dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden + ) + tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) + if PERMUTE_PROBS: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert + prob = tl.load(probs_ptr + prob_off) + if pid_h == 0: + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + if prob == 0.0: + # for routing_map padding + # dst_row != -1 and prob == 0.0 means that this slot is padded + tl.store(output_ptr + output_off, 0.0, mask=mask) + else: + tl.store(output_ptr + output_off, inp, mask=mask) + else: + tl.store(output_ptr + output_off, inp, mask=mask) + + +try: + _permute_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), + ], + key=["hidden_size"], + )(_permute_kernel) +except RuntimeError: + pass + + +@triton.jit +def _unpermute_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + merging_probs_ptr, + permuted_probs_ptr, + unpermuted_probs_ptr, + # sizes + num_experts: tl.constexpr, + hidden_size: tl.constexpr, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_permuted_probs_token, + stride_unpermuted_probs_token, + stride_unpermuted_probs_expert, + # metas + PROBS_LOAD_WIDTH: tl.constexpr, + WITH_MERGING_PROBS: tl.constexpr, + PERMUTE_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + data_type = input_ptr.dtype.element_ty + compute_type = tl.float32 + + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + if PERMUTE_PROBS: + # write 0.0 to probs_grad that are not routed + if pid_h == 0: + map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) + unpermuted_prob_off = ( + pid_t * stride_unpermuted_probs_token + + stride_unpermuted_probs_expert * map_load_off + ) + tl.store( + unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts + ) + accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + n_routed = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + src_row = tl.load( + row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert + ).to(tl.int64) + input_off = src_row * stride_input_token + current_offset * stride_input_hidden + inp = tl.load(input_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + if WITH_MERGING_PROBS: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + merging_prob_off = ( + pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + inp *= merging_prob + accumulator += inp + if PERMUTE_PROBS: + if pid_h == 0: + expert_idx = tl.load( + row_id_map_ptr + + pid_t * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + unpermuted_prob_off = ( + pid_t * stride_unpermuted_probs_token + + expert_idx * stride_unpermuted_probs_expert + ) + permuted_prob_off = src_row * stride_permuted_probs_token + prob = tl.load(permuted_probs_ptr + permuted_prob_off) + tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) + accumulator = accumulator.to(data_type) + dst_row = pid_t.to(tl.int64) + output_off = dst_row * stride_output_token + current_offset * stride_output_hidden + tl.store(output_ptr + output_off, accumulator, mask=mask) + + +try: + _unpermute_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), + ], + key=["hidden_size"], + )(_unpermute_kernel) +except RuntimeError: + pass + + +@triton.jit +def _unpermute_bwd_with_merging_probs_kernel( + # pointers + fwd_output_grad_ptr, + fwd_input_grad_ptr, + fwd_input_ptr, + merging_probs_ptr, + merging_probs_grad_ptr, + row_id_map_ptr, + # sizes + num_experts: tl.constexpr, + hidden_size: tl.constexpr, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + stride_fwd_output_grad_token, + stride_fwd_output_grad_hidden, + stride_fwd_input_grad_token, + stride_fwd_input_grad_hidden, + stride_fwd_input_token, + stride_fwd_input_hidden, + stride_merging_probs_token, + stride_merging_probs_expert, + stride_merging_probs_grad_token, + stride_merging_probs_grad_expert, + # metas + PROBS_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + data_type = fwd_output_grad_ptr.dtype.element_ty + compute_type = tl.float32 + + pid = tl.program_id(0) + map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) + token_probs_grad_off = ( + pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off + ) + tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts) + n_routed = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert + ) + for idx in tl.range(n_routed): + dst_row = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert + ).to(tl.int64) + expert_idx = tl.load( + row_id_map_ptr + + pid * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + src_row = pid.to(tl.int64) + input_off = ( + src_row * stride_fwd_output_grad_token + + current_offset * stride_fwd_output_grad_hidden + ) + inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + output = inp * merging_prob + output = output.to(data_type) + output_off = ( + dst_row * stride_fwd_input_grad_token + + current_offset * stride_fwd_input_grad_hidden + ) + tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) + + fwd_input_off = ( + dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden + ) + fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) + prob_grad_accum += fwd_input.to(compute_type) * inp + current_start += BLOCK_SIZE + probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) + + +try: + _unpermute_bwd_with_merging_probs_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), + ], + key=["hidden_size"], + )(_unpermute_bwd_with_merging_probs_kernel) +except RuntimeError: + pass + + +@triton.jit +def _make_chunk_sort_map_kernel( + # pointers + split_sizes_ptr, + sorted_indices_ptr, + dst_rows_ptr, + # sizes + num_splits: tl.constexpr, + # metas + IDX_LOAD_WIDTH: tl.constexpr, +): + pid = tl.program_id(0) + + load_split_offset = tl.arange(0, IDX_LOAD_WIDTH) + sorted_indices = tl.load( + sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits + ) + + # get chunk idx of the current token in the input tensor + input_split_sizes = tl.load( + split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 + ).to(tl.int32) + input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) + input_chunk_idx = tl.sum(input_split_sizes_mask) + input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) + in_chunk_offset = pid - input_split_sizes_presum + + # get chunk idx of the current token in the output tensor + output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0) + output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1) + + # make row_id_map + output_split_sizes = tl.load( + split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits + ).to(tl.int32) + output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) + dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + tl.store(dst_rows_ptr + pid, dst_row) + + +@triton.jit +def _sort_chunks_by_map_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, + # sizes + hidden_size: tl.constexpr, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, + # metas + PERMUTE_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + FORWARD: tl.constexpr, +): + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) + if FORWARD: + src_row = pid_t.to(tl.int64) + dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) + else: + src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) + dst_row = pid_t.to(tl.int64) + current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden + output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden + inp = tl.load(input_ptr + input_offsets, mask=mask) + tl.store(output_ptr + output_offsets, inp, mask=mask) + if PERMUTE_PROBS: + if pid_h == 0: + prob_off = src_row * stride_probs_token + prob = tl.load(probs_ptr + prob_off) + permuted_prob_off = dst_row * stride_permuted_probs_token + tl.store(permuted_probs_ptr + permuted_prob_off, prob) + + +try: + _sort_chunks_by_map_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), + ], + key=["hidden_size"], + )(_sort_chunks_by_map_kernel) +except RuntimeError: + pass diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 8c14d5ab7..e938509e5 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -29,12 +29,14 @@ import transformer_engine_torch as tex +from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv from . import torch_version from .utils import ( is_non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm, ) + from .constants import dist_group_type from .quantization import FP8GlobalStateManager, autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer @@ -46,7 +48,6 @@ from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from .tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage -from .triton.pad import pad_columnwise_scale_inv from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py index 76c9b98d0..80766864d 100644 --- a/transformer_engine/pytorch/triton/__init__.py +++ b/transformer_engine/pytorch/triton/__init__.py @@ -2,4 +2,4 @@ # # See LICENSE for license information. -"""Kernels written with OpenAI Triton.""" +"""PyTorch wrappers for Triton kernels.""" diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 7cfff1da9..d7e2256e2 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Efficient Cross Entropy kernels written with OpenAI Triton.""" +"""PyTorch wrapper functions for Cross Entropy Triton kernels.""" from typing import Union from functools import reduce @@ -12,257 +12,17 @@ import torch.distributed as dist import triton -import triton.language as tl - - -@triton.jit -def online_softmax_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes the m/d components on this TP rank for the online softmax. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride (int): The stride of the m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) - else: - X_y = float("-inf") - else: - X_y = float("-inf") - - m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 - - # 3. [Online softmax] first pass: find max + sum - m = float("-inf") # m is the max value. use the notation from the paper - d = 0.0 # d is the sum. use the notation from the paper - - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( - tl.float32 - ) - block_max = tl.max(X_block) - m_new = tl.maximum(m, block_max) - d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) - m = m_new - - tl.store(m_d_X_y_ptr, m) - tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) - tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) - - -@triton.jit -def cross_entropy_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - loss_ptr, - loss_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - world_size, - ignore_idx, - n_cols, - n_non_ignore, - reduce_loss: tl.constexpr, - label_smoothing: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes both cross entropy loss and the gradient of the input. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - loss_ptr: Pointer to tensor to store the loss. - loss_stride (int): The stride of the loss tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride: The stride of m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - world_size (int): The size of world involved in this distributed loss calculation. - ignore_idx (int): Tokens to be ignored for loss and gradient calculation. - n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - if y == ignore_idx: - # set all X_ptr as 0 - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) - return - - loss_ptr += program_id * loss_stride - m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride - - # Need to reduce the m/d/X_y values from other TP ranks - m = tl.load(m_d_X_y_ptr) - d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) - ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) - - for i in range(1, world_size): - offset = i * 3 * n_non_ignore * m_d_X_y_stride - access_ptr = m_d_X_y_ptr + offset - m_new = tl.load(access_ptr) - d_new = tl.load(access_ptr + m_d_X_y_stride) - X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) - - d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) - m = tl.maximum(m, m_new) - ori_X_y = tl.maximum(ori_X_y, X_y_new) - - # Label smoothing is a general case of normal cross entropy - scaled_x_sum = 0.0 - eps = label_smoothing / (n_cols * world_size) - - # 4. [Online softmax] second pass: calculate the gradients - # dx_y = (softmax(x_y) - 1) / N - # dx_i = softmax(x_i) / N, i != y - # N is the number of non ignored elements in the batch - # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N - # = dx_i - (1 - label_smoothing) / N - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) - grad_dtype = X_block.dtype - X_block = X_block.to(tl.float32) - if label_smoothing > 0: - # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - # Scale gradients based on reduction mode - # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore - # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here - if reduce_loss: - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) - else: - X_block = tl.exp(X_block - m) / d - eps - tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) - - # We need tl.debug_barrier() to ensure the new result of X_ptr is written - tl.debug_barrier() - - # 5. Calculate the loss - - # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) - # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - loss = -(ori_X_y - m - tl.log(d)) - - # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps - # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) - # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) - # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) - # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 - if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) - loss = loss * (1 - label_smoothing) + smooth_loss - - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx) - # Apply the same conditional scaling logic for the target token - if reduce_loss: - X_y += -(1 - label_smoothing) / (n_non_ignore) - else: - X_y += -(1 - label_smoothing) - tl.store(X_ptr + y - vocab_start_idx, X_y) - - tl.store(loss_ptr, loss) +from transformer_engine.common.triton.cross_entropy import ( + online_softmax_kernel, + cross_entropy_kernel, + element_mul_kernel, +) # The optimal maximum block size depends on your hardware, your kernel, and your dtype MAX_FUSED_SIZE = 65536 // 2 -@triton.jit -def element_mul_kernel( - X_ptr, - X_stride, - grad_output_ptr, - grad_output_stride, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. - The multiplication is performed in-place on the tensor pointed by X_ptr. - - Parameters: - X_ptr: Pointer to the input tensor. - X_stride (int): The stride of the input tensor. - grad_output_ptr: Pointer to the gradient output value. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - # Get the program ID and convert it to int64 to avoid overflow - program_id = tl.program_id(0).to(tl.int64) - - # Locate the start index - X_ptr += program_id * X_stride - - # Load the gradient output value - grad_output_ptr += program_id * grad_output_stride - grad_output = tl.load(grad_output_ptr) - - # Perform the element-wise multiplication - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) - tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) - - def cross_entropy_forward( _input: torch.Tensor, target: torch.Tensor, diff --git a/transformer_engine/pytorch/triton/pad.py b/transformer_engine/pytorch/triton/pad.py index 29b0daf31..790b8277b 100644 --- a/transformer_engine/pytorch/triton/pad.py +++ b/transformer_engine/pytorch/triton/pad.py @@ -2,63 +2,12 @@ # # See LICENSE for license information. -"""NVFP4 padding kernels - -TODO(ksivamani): Documentation - -""" +"""PyTorch wrapper functions for padding Triton kernels.""" import torch - import triton -import triton.language as tl - - -@triton.autotune( - configs=[ - triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2), - triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2), - triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1), - ], - key=["out_dim0", "out_dim1"], -) -@triton.jit -def zero_pad_kernel( - inp_ptr, - out_ptr, - in_dim0: tl.constexpr, - in_dim1: tl.constexpr, - out_dim0: tl.constexpr, - out_dim1: tl.constexpr, - in_s0, - in_s1, - out_s0, - out_s1, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - """Pads a tensor assuming it's a columnwise scaling inverse.""" - - # tile over OUTPUT coordinates - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows - offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # output cols - om = offs_m[:, None] - on = offs_n[None, :] - - # edge masking for output - out_mask = (om < out_dim0) & (on < out_dim1) - - # valid input region is simply top-left (no offsets) - in_mask = (om < in_dim0) & (on < in_dim1) - - # load valid input, else zero (masked load touches memory only where True) - x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0) - # store to output (only within bounds of the output tile) - tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask) +from transformer_engine.common.triton.pad import zero_pad_kernel def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor: diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 1474a664c..da22299fe 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -2,197 +2,23 @@ # # See LICENSE for license information. -"""Permutation kernels written with OpenAI Triton.""" +"""PyTorch wrapper functions for Permutation Triton kernels.""" from typing import Union import torch import triton -import triton.language as tl -from triton.language import core -from triton.language.standard import _log2 -from packaging import version - - -# The following three argsort related kernels are adapted from -# the issue https://github.com/triton-lang/triton/issues/3698 - -get_int_dtype = core.get_int_dtype -if version.parse(triton.__version__) >= version.parse("3.5.0"): - get_int_dtype = triton.constexpr_function(get_int_dtype) - - -@triton.jit -def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): - n_outer: tl.constexpr = x.numel >> n_dims - shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)] - y = tl.reshape(x, shape) - z = tl.reshape(indices, shape) - - mask = tl.arange(0, 2)[None, :, None] - - l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to( - x.dtype - ) - r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to( - x.dtype - ) - - l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape) - r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) - - idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) - - il_value = l_value.to(idtype, bitcast=True) - ir_value = r_value.to(idtype, bitcast=True) - ix = x.to(idtype, bitcast=True) - - flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix)) - ret = ix ^ flag1 - flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix)) - ind = indices ^ flag2 - - return ret.to(x.dtype, bitcast=True), ind - - -@triton.jit -def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr): - n_outer: tl.constexpr = x.numel >> n_dims - tl.static_assert(stage <= n_dims) - """ - order_type 0 == ascending - order_type 1 == descending - order_type 2 == alternating - """ - if order == 2: - shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage] - flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) - else: - flip = tl.full(x.shape, value=order, dtype=tl.int32) - for i in tl.static_range(stage): - x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims) - return x, indices - - -@triton.jit -def _argsort(x, indices, n_dims: tl.constexpr): - for i in tl.static_range(1, n_dims + 1): - x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims) - return x, indices - - -@triton.jit -def _row_id_map_pass_1_kernel( - # pointers - routing_map_ptr, - row_id_map_ptr, - workspace_ptr, - # sizes - num_tokens, - # strides - stride_routing_map_token, - stride_routing_map_expert, - stride_row_id_map_token, - stride_row_id_map_expert, - # metas - BLOCK_SIZE: tl.constexpr, -): - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - expert_token_mask = tl.load( - routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, - mask=(offset < num_tokens), - other=0, - ).to(tl.int32) - row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask - tl.store( - row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, - row_id_within_token_block, - mask=offset < num_tokens, - ) - n_tokens_per_block = tl.sum(expert_token_mask) - tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block) - - -@triton.jit -def _row_id_map_pass_2_kernel( - # pointers - row_id_map_ptr, - workspace_ptr, - # sizes - num_tokens, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - # metas - WORKSPACE_LOAD_WIDTH: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n - offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - row_id_within_token_block = tl.load( - row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, - mask=(offset < num_tokens), - other=0, - ) - - workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) - n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx) - row_id = tl.where( - row_id_within_token_block == 0, - -1, - row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, - ) - tl.store( - row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, - row_id, - mask=(offset < num_tokens), - ) - - -@triton.jit -def _row_id_map_pass_3_kernel( - # pointers - row_id_map_ptr, - # sizes - num_experts: tl.constexpr, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - # metas - LOAD_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - n_dims: tl.constexpr = _log2(LOAD_SIZE) - off = tl.arange(0, LOAD_SIZE) - row_id_map = tl.load( - row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off, - mask=off < num_experts, - other=-1, - ) - n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0)) - indices = off - sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims) - tl.store( - row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert, - sorted_map, - mask=off < n_routed, - ) - tl.store( - row_id_map_ptr - + pid * stride_row_id_map_token - + (num_experts + off) * stride_row_id_map_expert, - indices, - mask=off < n_routed, - ) - tl.store( - row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert, - n_routed, - ) +from transformer_engine.common.triton.permutation import ( + _row_id_map_pass_1_kernel, + _row_id_map_pass_2_kernel, + _row_id_map_pass_3_kernel, + _permute_kernel, + _unpermute_kernel, + _unpermute_bwd_with_merging_probs_kernel, + _make_chunk_sort_map_kernel, + _sort_chunks_by_map_kernel, +) def make_row_id_map( @@ -292,103 +118,6 @@ def make_row_id_map( return row_id_map -@triton.jit -def _permute_kernel( - # pointers - input_ptr, - output_ptr, - row_id_map_ptr, - probs_ptr, - scale_ptr, - permuted_probs_ptr, - permuted_scale_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, - scale_hidden_dim, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_probs_token, - stride_probs_expert, - stride_scale_token, - stride_scale_hidden, - stride_permuted_probs_token, - stride_permuted_scale_token, - stride_permuted_scale_hidden, - # metas - PERMUTE_PROBS: tl.constexpr, - PERMUTE_SCALE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_t = tl.program_id(0) - pid_h = tl.program_id(1) - cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = cur_off < hidden_size - src_row = pid_t.to(tl.int64) - input_off = src_row * stride_input_token + cur_off * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - if PERMUTE_SCALE: - mask_scale = cur_off < scale_hidden_dim - scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden - scale = tl.load(scale_ptr + scale_off, mask=mask_scale) - n_routed = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + num_experts * 2 * stride_row_id_map_expert - ) - for idx in tl.range(n_routed): - dst_row = tl.load( - row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ).to(tl.int64) - output_off = dst_row * stride_output_token + cur_off * stride_output_hidden - if PERMUTE_SCALE: - permuted_scale_off = ( - dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden - ) - tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) - if PERMUTE_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + prob_off) - if pid_h == 0: - permuted_prob_off = dst_row * stride_permuted_probs_token - tl.store(permuted_probs_ptr + permuted_prob_off, prob) - if prob == 0.0: - # for routing_map padding - # dst_row != -1 and prob == 0.0 means that this slot is padded - tl.store(output_ptr + output_off, 0.0, mask=mask) - else: - tl.store(output_ptr + output_off, inp, mask=mask) - else: - tl.store(output_ptr + output_off, inp, mask=mask) - - -try: - _permute_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_permute_kernel) -except RuntimeError: - pass - - def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -468,116 +197,6 @@ def permute_with_mask_map( return output, permuted_scale, permuted_probs -@triton.jit -def _unpermute_kernel( - # pointers - input_ptr, - output_ptr, - row_id_map_ptr, - merging_probs_ptr, - permuted_probs_ptr, - unpermuted_probs_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_merging_probs_token, - stride_merging_probs_expert, - stride_permuted_probs_token, - stride_unpermuted_probs_token, - stride_unpermuted_probs_expert, - # metas - PROBS_LOAD_WIDTH: tl.constexpr, - WITH_MERGING_PROBS: tl.constexpr, - PERMUTE_PROBS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - data_type = input_ptr.dtype.element_ty - compute_type = tl.float32 - - pid_t = tl.program_id(0) - pid_h = tl.program_id(1) - current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - if PERMUTE_PROBS: - # write 0.0 to probs_grad that are not routed - if pid_h == 0: - map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) - unpermuted_prob_off = ( - pid_t * stride_unpermuted_probs_token - + stride_unpermuted_probs_expert * map_load_off - ) - tl.store( - unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts - ) - accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - n_routed = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + num_experts * 2 * stride_row_id_map_expert - ) - for idx in tl.range(n_routed): - src_row = tl.load( - row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ).to(tl.int64) - input_off = src_row * stride_input_token + current_offset * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - if WITH_MERGING_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - merging_prob_off = ( - pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - inp *= merging_prob - accumulator += inp - if PERMUTE_PROBS: - if pid_h == 0: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - unpermuted_prob_off = ( - pid_t * stride_unpermuted_probs_token - + expert_idx * stride_unpermuted_probs_expert - ) - permuted_prob_off = src_row * stride_permuted_probs_token - prob = tl.load(permuted_probs_ptr + permuted_prob_off) - tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) - accumulator = accumulator.to(data_type) - dst_row = pid_t.to(tl.int64) - output_off = dst_row * stride_output_token + current_offset * stride_output_hidden - tl.store(output_ptr + output_off, accumulator, mask=mask) - - -try: - _unpermute_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_unpermute_kernel) -except RuntimeError: - pass - - def unpermute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -644,110 +263,6 @@ def unpermute_with_mask_map( return output, unpermuted_probs -@triton.jit -def _unpermute_bwd_with_merging_probs_kernel( - # pointers - fwd_output_grad_ptr, - fwd_input_grad_ptr, - fwd_input_ptr, - merging_probs_ptr, - merging_probs_grad_ptr, - row_id_map_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - stride_fwd_output_grad_token, - stride_fwd_output_grad_hidden, - stride_fwd_input_grad_token, - stride_fwd_input_grad_hidden, - stride_fwd_input_token, - stride_fwd_input_hidden, - stride_merging_probs_token, - stride_merging_probs_expert, - stride_merging_probs_grad_token, - stride_merging_probs_grad_expert, - # metas - PROBS_LOAD_WIDTH: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - data_type = fwd_output_grad_ptr.dtype.element_ty - compute_type = tl.float32 - - pid = tl.program_id(0) - map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) - token_probs_grad_off = ( - pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off - ) - tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts) - n_routed = tl.load( - row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert - ) - for idx in tl.range(n_routed): - dst_row = tl.load( - row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert - ).to(tl.int64) - expert_idx = tl.load( - row_id_map_ptr - + pid * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - src_row = pid.to(tl.int64) - input_off = ( - src_row * stride_fwd_output_grad_token - + current_offset * stride_fwd_output_grad_hidden - ) - inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - merging_prob_off = ( - pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - output = inp * merging_prob - output = output.to(data_type) - output_off = ( - dst_row * stride_fwd_input_grad_token - + current_offset * stride_fwd_input_grad_hidden - ) - tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) - - fwd_input_off = ( - dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden - ) - fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) - prob_grad_accum += fwd_input.to(compute_type) * inp - current_start += BLOCK_SIZE - probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) - probs_grad_off = ( - pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert - ) - tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) - - -try: - _unpermute_bwd_with_merging_probs_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_unpermute_bwd_with_merging_probs_kernel) -except RuntimeError: - pass - - def unpermute_with_mask_map_bwd_with_merging_probs( fwd_output_grad: torch.Tensor, row_id_map: torch.Tensor, @@ -813,47 +328,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( return act_grad, merging_probs_grad -@triton.jit -def _make_chunk_sort_map_kernel( - # pointers - split_sizes_ptr, - sorted_indices_ptr, - dst_rows_ptr, - # sizes - num_splits: tl.constexpr, - # metas - IDX_LOAD_WIDTH: tl.constexpr, -): - pid = tl.program_id(0) - - load_split_offset = tl.arange(0, IDX_LOAD_WIDTH) - sorted_indices = tl.load( - sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits - ) - - # get chunk idx of the current token in the input tensor - input_split_sizes = tl.load( - split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 - ).to(tl.int32) - input_split_sizes_cumsum = tl.cumsum(input_split_sizes) - input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) - input_chunk_idx = tl.sum(input_split_sizes_mask) - input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) - in_chunk_offset = pid - input_split_sizes_presum - - # get chunk idx of the current token in the output tensor - output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0) - output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1) - - # make row_id_map - output_split_sizes = tl.load( - split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits - ).to(tl.int32) - output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) - dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset - tl.store(dst_rows_ptr + pid, dst_row) - - def make_chunk_sort_map( split_sizes: torch.Tensor, sorted_indices: torch.Tensor, @@ -886,67 +360,6 @@ def make_chunk_sort_map( return row_id_map -@triton.jit -def _sort_chunks_by_map_kernel( - # pointers - input_ptr, - output_ptr, - row_id_map_ptr, - probs_ptr, - permuted_probs_ptr, - # sizes - hidden_size: tl.constexpr, - # strides - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_probs_token, - stride_permuted_probs_token, - # metas - PERMUTE_PROBS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - FORWARD: tl.constexpr, -): - pid_t = tl.program_id(0) - pid_h = tl.program_id(1) - if FORWARD: - src_row = pid_t.to(tl.int64) - dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) - else: - src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64) - dst_row = pid_t.to(tl.int64) - current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden - output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden - inp = tl.load(input_ptr + input_offsets, mask=mask) - tl.store(output_ptr + output_offsets, inp, mask=mask) - if PERMUTE_PROBS: - if pid_h == 0: - prob_off = src_row * stride_probs_token - prob = tl.load(probs_ptr + prob_off) - permuted_prob_off = dst_row * stride_permuted_probs_token - tl.store(permuted_probs_ptr + permuted_prob_off, prob) - - -try: - _sort_chunks_by_map_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_sort_chunks_by_map_kernel) -except RuntimeError: - pass - - def sort_chunks_by_map( inp: torch.Tensor, row_id_map: torch.Tensor, From 7a5859834084c8c9cc88bc1711270e0ee232e6bc Mon Sep 17 00:00:00 2001 From: Teddy Do Date: Mon, 10 Nov 2025 10:59:22 -0800 Subject: [PATCH 113/141] [JAX] Fused layers argument default values changed (#2347) * Changing default activations in MLP, TransformerLayer, dropout rate after FC1 to 0, and return_layernorm_output to False Signed-off-by: tdophung * Fixing the failing tests by hard coding arguments to the previous values instead of relying on newer default values Signed-off-by: tdophung * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tdophung Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_distributed_layernorm_mlp.py | 2 ++ tests/jax/utils.py | 12 ++++++------ transformer_engine/jax/flax/module.py | 16 ++++++++-------- transformer_engine/jax/flax/transformer.py | 8 ++++---- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 339097e9c..667840da2 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -389,6 +389,7 @@ def _test_layernorm_mlp( intermediate_dim=INTERMEDIATE, activations=activation_type, use_bias=use_bias, + return_layernorm_output=True, ) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( @@ -417,6 +418,7 @@ def _test_layernorm_mlp( dot_1_input_axes=DOT_1_INPUT_AXES, dot_2_input_axes=DOT_2_INPUT_AXES, name="mlp", + return_layernorm_output=True, ) params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( diff --git a/tests/jax/utils.py b/tests/jax/utils.py index c28e68a15..bbe8e6582 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -364,9 +364,9 @@ class MlpBlock(nn.Module): transpose_batch_sequence: bool intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ("relu",) + activations: Sequence[Union[str, Callable]] = ("gelu",) kernel_init: Initializer = None - intermediate_dropout_rate: float = 0.1 + intermediate_dropout_rate: float = 0.0 intermediate_dropout_dims: Sequence[int] = () use_bias: bool = False dtype: Any = jnp.float32 @@ -1035,14 +1035,14 @@ class EncoderLayer(nn.Module): hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 + intermediate_dropout: float = 0.0 intermediate_dropout_dims: Sequence[int] = () transpose_batch_sequence: bool = True float32_attention_logits: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 - mlp_activations: Sequence[str] = ("relu",) + mlp_activations: Sequence[str] = ("gelu",) use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False @@ -1199,14 +1199,14 @@ class DecoderLayer(nn.Module): hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 + intermediate_dropout: float = 0.0 intermediate_dropout_dims: Sequence[int] = () transpose_batch_sequence: bool = True float32_attention_logits: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 - mlp_activations: Sequence[str] = ("relu",) + mlp_activations: Sequence[str] = ("gelu",) use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c54ecb236..33ea61098 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -597,7 +597,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): bias_axes: Tuple[str, ...], default = () The name of axes used to shard bias with a corresponding mesh, only used when :attr:`use_bias=True`. - return_layernorm_output: bool, default = True + return_layernorm_output: bool, default = False Indicate whether to return the output of layer normalization. If set False, return None as the second tensor in outputs. enable_low_rank_adaptation: bool, default = False @@ -644,7 +644,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): use_bias: bool = False bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = () - return_layernorm_output: bool = True + return_layernorm_output: bool = False enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None @@ -891,10 +891,10 @@ class LayerNormMLP(TransformerEngineBase): The name of axes used to shard bias with a corresponding mesh for the weight of the second dense layer transformation. Only used when :attr:`use_bias=True`. - return_layernorm_output: bool, default = True + return_layernorm_output: bool, default = False Indicate whether to return the output of layer normalization. If set False, return None as the second tensor in outputs. - activations: Sequence[Union[str, Callable]], default = ('relu',) + activations: Sequence[Union[str, Callable]], default = ('gelu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. activation_params: dict, default = None @@ -903,7 +903,7 @@ class LayerNormMLP(TransformerEngineBase): need additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. - intermediate_dropout_rate: float, default = 0.1 + intermediate_dropout_rate: float, default = 0.0 Dropout probability for the dropout op after the :attr:`activations`. intermediate_hidden_dropout_dims: Sequence[int], default = () Dimensions that will share the same dropout mask for hidden @@ -959,11 +959,11 @@ class LayerNormMLP(TransformerEngineBase): bias_init: Initializer = nn.initializers.zeros bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...] = ("embed",) - return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",) + return_layernorm_output: bool = False + activations: Sequence[Union[str, Callable]] = ("gelu",) activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" - intermediate_dropout_rate: float = 0.1 + intermediate_dropout_rate: float = 0.0 intermediate_hidden_dropout_dims: Sequence[int] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 86af6cf49..d096e7997 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1620,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Dimensions that will share the same dropout mask for hidden attention_dropout: float, default = 0.1 Dropout probability for the dropout op during multi-head attention. - intermediate_dropout: float, default = 0.1 + intermediate_dropout: float, default = 0.0 Dropout probability for the dropout op after FC1 layer. intermediate_dropout_dims: Sequence[int], default = () Dimensions that will share the same dropout mask for hidden after FC1 layer. @@ -1635,7 +1635,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') Used for initializing weights of FC1 and FC2 layers. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). - mlp_activations: Sequence[str], default = ('relu', ) + mlp_activations: Sequence[str], default = ('gelu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. mlp_activation_params: dict = None @@ -1755,12 +1755,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 + intermediate_dropout: float = 0.0 intermediate_dropout_dims: Sequence[int] = () dropout_rng_name: str = "dropout" mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None - mlp_activations: Sequence[str] = ("relu",) + mlp_activations: Sequence[str] = ("gelu",) mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros From 29537c96d06c4f3965fd3bcc668810dbb4245aaf Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 10 Nov 2025 19:31:55 -0800 Subject: [PATCH 114/141] [PyTorch] FSDP2 Support for TE (#2245) * fix for float8 tensor fsdp2 training Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * zeros_like should return fp32 for fsdp2 to work Signed-off-by: Varun Thumbe * minor cleanup Signed-off-by: Varun Thumbe * fix unsharded weights not releasing memory Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement using fsdp preallgather and postallgather functions Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * FSDP2 works on Hopper/L40 Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor comment Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * some fixes for fp8 + handwavy changes for mxfp8 Signed-off-by: Varun Thumbe * only transpose saved for backward pass allgather in case of L40/Hoppergst Signed-off-by: Varun Thumbe * missed minor change to hopper use-case Signed-off-by: Varun Thumbe * communicate only required data in mxfp8, fix for updating weight usages when required instead of doing upfront in fwd pass Signed-off-by: Varun Thumbe * changes for meta Dtensors for weights and better all gather data handling in fsdp hook functions Signed-off-by: Varun Thumbe * better solution to figure out forward pass in FSDP2 Signed-off-by: Varun Thumbe * adress review comments Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * everything functioning except hack for transformerlayer Signed-off-by: Varun Thumbe * fix merge conflict Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert change of commit id for cudnnt-frontend Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * unnecessary change Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor issues with linting, add some comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor stuff Signed-off-by: Varun Thumbe * revert space removal Add default usage handling for rowwise and columnwise data. Signed-off-by: vthumbe1503 * fix the fsdp state collection issue, and minor review comments addressing Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert change for dgrad redundant computation Signed-off-by: Varun Thumbe * bug: get fsdp param group's training state instead of root training state; address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * address review comments Signed-off-by: Varun Thumbe * address coderabbit review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * address review comments Signed-off-by: Varun Thumbe * adress review comments; fix fp8 allgather test to do after fsdp lazy init Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * remove detach Signed-off-by: Varun Thumbe * do what makes sense Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/tensor/float8_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: vthumbe1503 * address review comments Signed-off-by: Varun Thumbe * address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * adress review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * have better dtype for fsdp_post_all_gather arguments Signed-off-by: Varun Thumbe * minor comment Signed-off-by: Varun Thumbe * improve comment Signed-off-by: Varun Thumbe * fix the error in CI Signed-off-by: Varun Thumbe * minor comment add Signed-off-by: Varun Thumbe * accidentally removed view function Signed-off-by: Varun Thumbe * fix minor bug for h100 Signed-off-by: Varun Thumbe * minor addition Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * implement padding removal/addition for allgather Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/tensor/mxfp8_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint error Signed-off-by: Varun Thumbe * adress review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * improve the reset parameter logic for dtensors Signed-off-by: Varun Thumbe * other cosmetic changes Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cosmetic changes Signed-off-by: Varun Thumbe * cosmetic changes Signed-off-by: Varun Thumbe * Update transformer_engine/pytorch/module/layernorm_linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- tests/pytorch/distributed/run_fsdp2_model.py | 321 +++++++++++++---- tests/pytorch/distributed/test_torch_fsdp2.py | 18 +- transformer_engine/pytorch/distributed.py | 37 ++ transformer_engine/pytorch/module/base.py | 43 ++- .../pytorch/module/grouped_linear.py | 24 +- .../pytorch/module/layernorm_linear.py | 19 +- .../pytorch/module/layernorm_mlp.py | 20 +- transformer_engine/pytorch/module/linear.py | 12 +- .../pytorch/quantized_tensor.py | 12 +- .../pytorch/tensor/float8_tensor.py | 230 +++++++++++- .../pytorch/tensor/mxfp8_tensor.py | 341 +++++++++++++++++- 11 files changed, 928 insertions(+), 149 deletions(-) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index d3f8c82ba..c34329924 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -9,57 +9,73 @@ import argparse import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import ( + Format, + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, +) import torch import torch.distributed as dist +from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn, optim from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh +from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext +LOCAL_RANK = None -class SimpleNet(nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super(SimpleNet, self).__init__() - self.fc1 = te.Linear(input_size, hidden_size) - self.fc2 = te.Linear(hidden_size, output_size) - def forward(self, x): - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x - - -def save_custom_attrs(module): - custom_attrs = {} - for name, param in module.named_parameters(): - attrs = vars(param) - custom_attrs[name] = {k: v for k, v in attrs.items()} - return custom_attrs - - -def restore_custom_attrs(module, custom_attrs): - for name, param in module.named_parameters(): - if name in custom_attrs: - for attr_name, attr_value in custom_attrs[name].items(): - setattr(param, attr_name, attr_value) +def dist_print(msg): + if LOCAL_RANK == 0: + print(msg) def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") - parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") - parser.add_argument("--hidden-size", type=int, default=2048, help="Hidden layer size") - parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") - parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads") + parser.add_argument("--head-dim", type=int, default=64, help="Attention head size") + parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input") + parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input") + parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.") parser.add_argument( "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." ) + parser.add_argument( + "--recipe", + type=str, + default="mx_fp8_block_scaling", + help="Quantizer type.", + choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"], + ) + parser.add_argument( + "--layer-type", + type=str, + default="TransformerLayer", + choices=[ + "Linear", + "LayerNormLinear", + "LayerNormMLP", + "MultiheadAttention", + "TransformerLayer", + ], + help="Transformer Engine layer type", + ) + parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model") parser.add_argument( "--iter", type=int, default=10, help="Number of iterations for forward pass" ) + parser.add_argument( + "--device", + type=str, + default="meta", + help="Device to run the model on.", + choices=["cuda", "meta"], + ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") # Adding hsdp_dim as a list argument, comma-separated parser.add_argument( @@ -74,10 +90,170 @@ def _parse_args(argv=None, namespace=None): return args -sub_modules_to_wrap = [te.Linear] +## Methods to help initialize the TE model in an FSDP2 setting +## with required configurations based on command line args +def get_te_layer_from_string(layer_name): + te_layer_types = [ + te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + te_layer_names = [layer.__name__ for layer in te_layer_types] + te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types)) + if layer_name.lower() not in te_layer_map.keys(): + raise argparse.ArgumentTypeError( + f'"{layer_name}" is not a valid Transformer Engine layer, ' + f"please choose layer from {te_layer_names}." + ) + return te_layer_map[layer_name.lower()] + + +def get_recipe_from_string(recipe, fp8_format=Format.HYBRID): + if recipe == "delayed_scaling": + return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + elif recipe == "current_scaling": + return Float8CurrentScaling(fp8_format=fp8_format) + elif recipe == "mx_fp8_block_scaling": + return MXFP8BlockScaling(fp8_format=fp8_format) + else: + raise ValueError(f"Unknown quantizer type: {recipe}") + + +def init_te_model(config): + hidden_size = config.num_heads * config.head_dim + args = [hidden_size, hidden_size] + inp_shape = [config.seq_length, config.batch_size, hidden_size] + out_shape = [config.seq_length, config.batch_size, hidden_size] + if config.params_dtype == "float16": + params_dtype = torch.float16 + elif config.params_dtype == "bfloat16": + params_dtype = torch.bfloat16 + else: + params_dtype = torch.float32 + kwargs = { + "params_dtype": params_dtype, + } + kwargs["device"] = config.device + + layer_type = get_te_layer_from_string(config.layer_type) + # We are creating model in a way so that we can test both reshard_after_forward=True/False cases. + # more details below. + if layer_type in [te.MultiheadAttention, te.TransformerLayer]: + # For this case, we are creating a model that resemebles production use-cases + # wherein there are mltiple TransformerLayers in the model. And we would need + # to shard each transformer layer. Since each transformer layer is not a root module, + # FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model. + args[1] *= 4 # FFN hidden size + args.append(config.num_heads) + kwargs["fuse_qkv_params"] = True + if layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)]) + elif layer_type == te.LayerNormLinear: + # For this case, we are creating a model with just one LayerNormLinear layer + # so that the model itself is a root module, and FSDP2's fully_shard assigns + # reshard_after_forward=True for the parameters of these model. + args[1] *= 3 # QKV projection + out_shape[-1] *= 3 + model = layer_type(*args, **kwargs) + else: + model = layer_type(*args, **kwargs) + + return model, inp_shape, out_shape + + +def get_device_mesh(world_size, sharding_dims): + dist_print(f"sharding-dims:{sharding_dims}") + device_ids = list(range(world_size)) + if sharding_dims is None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(sharding_dims) == 1: + assert sharding_dims[0] == world_size + mesh = DeviceMesh("cuda", device_ids) + elif len(sharding_dims) == 2: # HSDP + assert sharding_dims[0] * sharding_dims[1] == world_size + mesh = init_device_mesh( + "cuda", + (sharding_dims[0], sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + return mesh + + +def shard_model_with_fsdp2(model, mesh): + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + return model + + +#### Methods to save the custom attributes of QuantizedTensors before sharding +#### them with FSDP2, and restore them after sharding. +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + if isinstance(param, QuantizedTensor): + # Ignore FP8 metadata attributes. Otherwise we will save duplicate copies + # for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save. + ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")] + else: + ignore_keys = [] + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +@torch.no_grad() +def test_fp8_fsdp2_allgather(model): + # Do manual allgather in fp32 and match against fp8 allgather done + # with fsdp2 + # FP32 manual weight allgather + fp32_allgathered_params = {} + for name, param in model.named_parameters(): + assert isinstance(param, DTensor) + local_tensor = param._local_tensor + device_mesh = param.device_mesh + dist_group = ( + device_mesh.get_group(mesh_dim="shard") + if device_mesh.ndim > 1 + else device_mesh.get_group() + ) + # Perform manual allgather on local_tensor. zeros_like will create hp tensor since torch_dispatch + # for local_tensor will go down the dequantization route. + gathered_tensor = [ + torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group)) + ] + dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group) + full_tensor = torch.cat(gathered_tensor, dim=0) + fp32_allgathered_params[name] = full_tensor + # FP8 allgather using FSDP2 + for module in model.modules(): + # Not all modules are wrapped/sharded with FSDP2. + if hasattr(module, "unshard"): + module.unshard() + # Make sure allgathered parameters match exactly + for name, param in model.named_parameters(): + assert torch.allclose(param.dequantize(), fp32_allgathered_params[name]) + # Revert model to original sharded state + for module in model.modules(): + # Not all modules are wrapped/sharded with FSDP2. + if hasattr(module, "reshard"): + module.reshard() def _train(args): + global LOCAL_RANK assert "TORCHELASTIC_RUN_ID" in os.environ WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) @@ -103,74 +279,69 @@ def _train(args): # FP8 Configuration fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") - - # Create build context manager - if args.fp8_init: - from transformer_engine.pytorch import quantized_model_init + fp8_recipe = get_recipe_from_string(args.recipe, fp8_format) - build_model_context = quantized_model_init() + build_model_context_args = {} + if not args.fp8_init: + # Build model context (FP8 init) + build_model_context = nullcontext else: - build_model_context = nullcontext() + from transformer_engine.pytorch import fp8_model_init - # Build the model with the specified context - with build_model_context: - model = SimpleNet(args.input_size, args.hidden_size, args.output_size) + build_model_context = fp8_model_init + build_model_context_args["enabled"] = True + build_model_context_args["recipe"] = fp8_recipe - # Move the model to the correct device - model.to(device) + dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB") + # Create the model on the meta/cuda device as per args + with build_model_context(**build_model_context_args): + model, inp_shape, out_shape = init_te_model(args) + dist_print( + f"Memory after model init on device {args.device}:" + f" {torch.cuda.memory_allocated(device)/1e6} MB" + ) - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") # Creating a DeviceMesh for fully_shard world_size = int(WORLD_SIZE) - device_ids = list(range(world_size)) - if LOCAL_RANK == 0: - print(f"sharding-dims:{args.sharding_dims}") # Setup the sharding mesh for FSDP/HSDP - if args.sharding_dims == None: # FSDP - mesh = DeviceMesh("cuda", device_ids) - elif len(args.sharding_dims) == 1: - assert args.sharding_dims[0] == device_ids[-1] + 1 - mesh = DeviceMesh("cuda", device_ids) - elif len(args.sharding_dims) == 2: # HSDP - assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 - mesh = init_device_mesh( - "cuda", - (args.sharding_dims[0], args.sharding_dims[1]), - mesh_dim_names=("replicate", "shard"), - ) - else: - assert False - - # Apply FSDP/HSDP + mesh = get_device_mesh(world_size, args.sharding_dims) custom_attrs = save_custom_attrs(model) - for sub_module in model.modules(): - if any( - isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap - ): - fully_shard(sub_module, mesh=mesh) - fully_shard(model, mesh=mesh) + model = shard_model_with_fsdp2(model, mesh) restore_custom_attrs(model, custom_attrs) + # model now has DTensors as its parameters + + if args.device == "meta": + # After FSDP2 has been applied, materialize and initialize the sharded parameters + # TE base.py's reset_parameters() handles DTensors with FP8 initialization + for module in model.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + dist_print(f" Sharded parameters materialized and initialized on cuda device.") + + dist_print( + f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" + ) optimizer = optim.Adam(model.parameters(), lr=1e-3) for iteration in range(args.iter): # Zero the parameter gradients optimizer.zero_grad() - input_data = torch.randn(args.batch_size, args.input_size).to(device) + input_data = torch.randn(inp_shape).to(device) with te.autocast(enabled=True, recipe=fp8_recipe): output = model(input_data) - target = torch.randn(args.batch_size, args.output_size).to(device) + target = torch.randn(out_shape).to(device) loss = F.mse_loss(output, target) loss.backward() optimizer.step() - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + dist_print(f"Iteration {iteration} completed with loss {loss.item()}") + + # Some of the FSDP states are lazy initialized during FSDP forward pass + # so testing fp8 allgather at the end of the training loop. + if args.fp8_init: + test_fp8_fsdp2_allgather(model) dist.destroy_process_group() - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Done...") return 0 diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index 8fe4e8bc7..91d6fc6ed 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -12,22 +12,26 @@ fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) - +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) NUM_PROCS: int = torch.cuda.device_count() -def _run_test(fp_init, sharding_dims): +def _run_test(fp_init, sharding_dims, recipe, layer_type): test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] if fp_init: test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: test_cmd += ["--sharding-dims", str(sharding_dims[0])] elif len(sharding_dims) == 2: test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] else: assert False + test_cmd += ["--recipe", recipe] + test_cmd += ["--layer-type", layer_type] + result = subprocess.run(test_cmd, env=os.environ, check=True) @@ -36,16 +40,20 @@ def _run_test(fp_init, sharding_dims): @pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") @pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("fp8_init", (False, True)) -def test_distributed(fp8_init, sharding_dims): +@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling")) +@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer")) +def test_distributed(fp8_init, sharding_dims, recipe, layer_type): # Skip invalid configurations if torch.cuda.device_count() < 4: pytest.skip("FSDP2 test requires at least 4 GPUs") - if fp8_init and not fp8_available: + if recipe == "mx_fp8_block_scaling" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + elif not fp8_available: pytest.skip(reason_for_no_fp8) - _run_test(fp8_init, sharding_dims) + _run_test(fp8_init, sharding_dims, recipe, layer_type) def test_dummy() -> None: diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e938509e5..620ea8301 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1886,6 +1886,43 @@ def allreduce( return inp, handle +def _get_module_fsdp_state(module): + """ + If module is an FSDP module, return its _FSDPState. + Otherwise, return the _FSDPState of the closest parent FSDP module + in the module hierarchy the module belongs to. + """ + + if hasattr(module, "_get_fsdp_state"): + # this will return correct fsdp state if module itself is an fsdp module + fsdp_state = module._get_fsdp_state() + elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None: + # See if we have cached the parent fsdp state of the module + fsdp_state = module._te_cached_parent_fsdp_state + else: + from torch.distributed._composable_state import _module_state_mapping + + # Otherwise get the fsdp state of lca of module in the module hierarchy + min_nodes_in_parent = float("inf") + closest_parent_fsdp_mod = None + for fsdp_mod in _module_state_mapping.keys(): + all_submodules = list(fsdp_mod.modules()) + for submodule in all_submodules: + if submodule is module: + if min_nodes_in_parent > len(all_submodules): + closest_parent_fsdp_mod = fsdp_mod + min_nodes_in_parent = len(all_submodules) + if closest_parent_fsdp_mod is None: + raise RuntimeError( + "Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules." + ) + fsdp_state = closest_parent_fsdp_mod._get_fsdp_state() + # Cache the parent fsdp state of the module to avoid recomputing + # the closest parent fsdp module. + module._te_cached_parent_fsdp_state = fsdp_state + return fsdp_state + + def _fsdp_scatter_tensors( fsdp_group: dist_group_type, *tensors: torch.Tensor, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 9b6ca9d9c..d2abe3a2d 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -17,6 +17,7 @@ import torch import torch.nn.functional as F +from torch.distributed.tensor import DTensor import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe @@ -1244,7 +1245,12 @@ def register_parameter(self, name, param, **kwargs): metedata used in deferred initialization. """ super().register_parameter(name, param) - self.param_init_meta[name] = _ParameterInitMeta(**kwargs) + # Initialize param_init_meta exactly once during the init. FSDP2 can call + # register parameter again to change parameters to DTensors. And it calls + # it without custom fp8 specific kwargs that we need. And so we dont want + # to reset/loose our fp8 init attributes. + if hasattr(self, "param_init_meta") and name not in self.param_init_meta: + self.param_init_meta[name] = _ParameterInitMeta(**kwargs) def reset_parameters(self, defer_init: Optional[bool] = False) -> None: """ @@ -1256,10 +1262,14 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: return for name, param in self.named_parameters(recurse=False): + # Check if parameter is a DTensor (FSDP2) or regular tensor + is_dtensor = isinstance(param, DTensor) + dtensor_param = param if is_dtensor else None + # Need to update/quantize local tensor in case of DTensor + param = param._local_tensor if is_dtensor else param # Ensure parameter is on a real device if param.device == torch.device("meta"): param = torch.empty_like(param, device="cuda") - # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker @@ -1288,7 +1298,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False - + if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): + device_mesh = dtensor_param.device_mesh + amax_reduction_group = ( + device_mesh.get_group(mesh_dim="shard") + if device_mesh.ndim > 1 + else device_mesh.get_group() + ) + quantizer.amax_reduction_group = amax_reduction_group + quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) @@ -1296,7 +1314,18 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. - param = torch.nn.Parameter(param) + if is_dtensor: + # recreate the DTensor from the parameter. + dtensor_param = DTensor.from_local( + param, + device_mesh=dtensor_param.device_mesh, + placements=dtensor_param.placements, + shape=dtensor_param.size(), + stride=dtensor_param.stride(), + ) + dtensor_param = torch.nn.Parameter(dtensor_param) + else: + param = torch.nn.Parameter(param) # Keep high-precision values on CPU if needed if high_precision_init_val is not None: @@ -1324,8 +1353,12 @@ def clear(self): param._high_precision_init_val = high_precision_init_val param.get_high_precision_init_val = MethodType(get, param) param.clear_high_precision_init_val = MethodType(clear, param) + # Update the parameter based on its type - setattr(self, name, param) + if not is_dtensor: + setattr(self, name, param) + else: + setattr(self, name, dtensor_param) @abstractmethod def forward(self): diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4d6b2f23b..59dc2b299 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -108,9 +108,15 @@ def forward( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) - if weight_quantizers[0] is not None: + # No need to set the quantizer states if weight is already quantized + if weight_quantizers[0] is not None and not isinstance( + weights[0], QuantizedTensorStorage + ): for weight_quantizer in weight_quantizers: weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + elif isinstance(weights[0], QuantizedTensorStorage): + # If weights are already quantized, no need to set quantizer states + weight_quantizers = [weight._quantizer for weight in weights] if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) @@ -205,10 +211,6 @@ def forward( inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms - if inp.requires_grad: - for weight in weights_fp8: - if isinstance(weight, QuantizedTensorStorage): - weight.update_usage(columnwise_usage=True) if cpu_offloading: ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad") @@ -354,13 +356,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - - for weight, quantizer in zip(weights, ctx.weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorStorage): - weight.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + # Make sure weights are available in column-wise format + # for dgrad computation. + for weight in weights: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights, grad_output, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 933c7cde5..abe8c5829 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -276,12 +276,15 @@ def forward( # Prepare weight tensor # ------------------------------------------------------ weightmat = weight - quantized_weight = False + is_weight_param_quantized = False if fp8 or debug: - quantized_weight = not isinstance(weight, QuantizedTensorStorage) + is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) # Configure quantizer - if weight_quantizer is not None: + # If weight is already quantized, no need to set quantizer states + if is_weight_param_quantized: + weight_quantizer = weight._quantizer + elif weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) # Get quantized weight @@ -413,10 +416,6 @@ def forward( ): ln_out.update_usage(rowwise_usage=False) - # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(weightmat, QuantizedTensorStorage): - weightmat.update_usage(columnwise_usage=True) - if cpu_offloading: mark_activation_offload(inputmat, mu, rsigma, ln_out) @@ -429,7 +428,7 @@ def forward( fsdp_group, mu, rsigma, - weightmat if quantized_weight else None, + weightmat if fp8 and not is_weight_param_quantized else None, ln_out if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -459,7 +458,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.requires_dgrad = inp_requires_grad ctx.requires_wgrad = weight.requires_grad - ctx.quantized_weight = quantized_weight + ctx.is_weight_param_quantized = is_weight_param_quantized if fuse_wgrad_accumulation and weight.requires_grad: # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates @@ -563,7 +562,7 @@ def backward( ctx.fsdp_shapes, mu, rsigma, - weight if ctx.fp8 and ctx.quantized_weight else None, + weight if ctx.fp8 and not ctx.is_weight_param_quantized else None, ln_out, ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 889f545c1..a358ae7dd 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -351,8 +351,17 @@ def forward( # which handles weight caching etc. # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + # No need to set the quantizer states if weights are already quantized + if isinstance(fc1_weight, QuantizedTensorStorage): + fc1_weight_quantizer = fc1_weight._quantizer + elif fc1_weight_quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + + if isinstance(fc2_weight, QuantizedTensorStorage): + fc2_weight_quantizer = fc2_weight._quantizer + elif fc2_weight_quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -538,13 +547,6 @@ def forward( # Cache state for backward pass if is_grad_enabled: - - # Weight with column-wise usage is needed for dgrad GEMM. - if isinstance(fc1_weight_final, QuantizedTensorStorage): - fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorStorage): - fc2_weight_final.update_usage(columnwise_usage=True) - if cpu_offloading: mark_activation_offload( inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ccb84e664..0e2310a5a 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -240,7 +240,8 @@ def forward( weightmat = weight if fp8 or debug: # Configure quantizer - if weight_quantizer is not None: + # No need to set the quantizer states if weight is already quantized + if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad if not columnwise_usage: columnwise_usage = ( @@ -248,7 +249,9 @@ def forward( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - + elif isinstance(weight, QuantizedTensor): + # If weight is already quantized, no need to set quantizer states + weight_quantizer = weight._quantizer # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch weightmat = module.get_weight_workspace( @@ -389,11 +392,6 @@ def forward( if backward_needs_input: saved_inputmat = inputmat - # Weight with column-wise usage is needed for dgrad GEMM. - if inp.requires_grad: - if isinstance(weightmat, QuantizedTensorStorage): - weightmat.update_usage(columnwise_usage=True) - if cpu_offloading and saved_inputmat is not None: mark_activation_offload(saved_inputmat) diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 15f5b6bd5..7d49e3964 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -433,6 +433,10 @@ def maybe_update_inplace(arg, new_arg, schema_arg): and schema_arg.alias_info.is_write ): arg.quantize_(new_arg) + elif isinstance(arg, list) and isinstance(new_arg, list): + # Recursively handle update for lists of tensors + for a, na in zip(arg, new_arg): + maybe_update_inplace(a, na, schema_arg) # In-place op: dequantize, perform op, and quantize if func._schema.is_mutable: @@ -489,20 +493,16 @@ def make_like( shape: Optional[Iterable[int]] = None, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, - data: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Create new quantized tensor By default, new tensor has the same attributes and underlying - data. + data. This function is intended to create view of tensors. """ - if shape is None: - shape = data.shape if data is not None else tensor.shape + shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() - if data is not None: - kwargs["data"] = data return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index de112bb3f..eb2ac9a58 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -4,10 +4,10 @@ """Tensor class with FP8 data""" from __future__ import annotations -from typing import Optional, Tuple, Iterable, Union +from typing import Any, Optional, Tuple, Iterable, Union import warnings - import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -299,14 +299,12 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - inner_dim = data.size(-1) + transpose_shape = [data.size(-1)] + list(data.shape[:-1]) data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, ) - # Construct FP8 tensor return Float8Tensor( shape=shape, @@ -534,9 +532,36 @@ def remove_caches(self) -> None: self._transpose = None @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): + def make_like( + cls, + tensor: QuantizedTensor, + *, + shape: Optional[Iterable[int]] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, + data: Optional[torch.Tensor] = None, + data_transpose: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Create new quantized tensor + + By default, new tensor has the same attributes and underlying + data. - # View op + """ + if shape is None and data is not None: + shape = data.shape + new_tensor = super().make_like( + tensor, shape=shape, dtype=dtype, requires_grad=requires_grad + ) + if data is not None: + new_tensor._data = data + if data_transpose is not None: + new_tensor._transpose = data_transpose + new_tensor._transpose_invalid = False + return new_tensor + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == aten.view.default: tensor = args[0] data = tensor._data @@ -555,6 +580,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): or out_transpose_shape[1:] != out_shape[:-1] ): out_transpose = None + else: + view_shape_for_transpose = [out_shape[-1]] + list(out_shape[:-1]) + out_transpose = out_transpose.view(*view_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, @@ -587,11 +615,37 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return [ - Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) - for split_tensor in func_out + t_func_out = [None] * len(func_out) + # Compute corresponding split of the transpose cache if available + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + ndim = data.dim() + # Figure out the original split dim + if "dim" in kwargs: + dim_to_split = kwargs["dim"] + else: + dim_to_split = args[2] if len(args) > 2 else 0 + # Dimension along which transpose needs to be split + t_dim = 0 if dim_to_split == ndim - 1 else dim_to_split + 1 + t_func_out = transpose.__torch_dispatch__( + func, + types, + [transpose, args[1], t_dim], + kwargs, + ) + outs = [ + Float8Tensor.make_like( + tensor, + data=split_tensor, + data_transpose=split_transpose_tensor, + shape=split_tensor.shape, + ) + for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) ] + return outs + if func == aten.new_zeros.default: + # create fresh new tensor with zeros. tensor = args[0] data = tensor._data func_out = data.__torch_dispatch__( @@ -600,17 +654,63 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + func_transposed_out = None + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + size = args[1] + t_shape = [size[-1]] + list(size[:-1]) + func_transposed_out = transpose.__torch_dispatch__( + func, + types, + [transpose, t_shape] + list(args[2:]), + kwargs, + ) + # deep copy the scale inverse tensor and quantizer as well. + scale_inv = tensor._scale_inv.detach().clone() + quantizer = tensor._quantizer.copy() + out_tensor = Float8Tensor( + data=func_out, + shape=func_out.shape, + dtype=tensor.dtype, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=scale_inv, + data_transpose=func_transposed_out, + quantizer=quantizer, + ) + return out_tensor + if func == torch.ops.aten.as_strided.default: tensor = args[0] data = tensor._data + # Apply as_strided to the primary uint8 data func_out = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + func_transposed_out = None + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + size = args[1] + stride = args[2] + if "storage_offset" in kwargs: + storage_offset = kwargs["storage_offset"] + else: + storage_offset = args[3] if len(args) > 3 else 0 + # Shape and strided needed for transpose matrix + t_size = [size[-1]] + list(size[:-1]) + t_stride = [stride[-1]] + list(stride[:-1]) + func_transposed_out = transpose.__torch_dispatch__( + func, + types, + [transpose, t_size, t_stride, storage_offset] + list(args[4:]), + kwargs, + ) + return Float8Tensor.make_like( + tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + ) + if func == torch.ops.aten.detach.default: return cls.detach(args[0]) if func == torch.ops.aten.clone.default: @@ -632,9 +732,105 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) else: pass - return super().__torch_dispatch__(func, types, args, kwargs) + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape) + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor + (For us same as self.stride()) + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this FP8 tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered.(In this case uint8 data tensor) + metadata: Tuple[Any]: Metadata needed for reconstructing the + Float8Tensor after all-gather. + """ + # pylint: disable=unused-argument + # Importing here to avoid circular imports + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None: + # When sharded weight is updated after reduce scattering the gradients in FSDP2, + # we need to do amax reduction across the mesh to make sure all weight shards are + # updated with same scale inverse. Setting the state below in the quantizer will make + # sure that updated Quantized weight tensor have same scale inverse across all shards. + self._quantizer.amax_reduction_group = mesh.get_group() + self._quantizer.with_amax_reduction = True + quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights + fsdp_state = _get_module_fsdp_state(module) + reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + # If weights are resharded after forward pass, then its enough to set the quantizer usages + # based on whether its forward or backward pass for the allgathered weights. + # If not resharded after forward pass, the same weights allgathered in forward + # are used again in backward and so we dont change the quantizer usages which might need + # both rowwise and columnwise usages. + if reshard_after_forward: + training_state = fsdp_state._fsdp_param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + # In case of hopper/L40, only one of data/transpose is needed + # based on forward or backward pass. So setting the quantizer usages appropriately. + quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) + sharded_tensors = (self._data,) + metadata = (self._scale_inv, self._fp8_dtype, quantizer) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Float8Tensor] = None, + ): + """Functions FSDP2 calls after all-gather of the + weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor. + param_dtype (torch.dtype): high precision dtype of the Float8Tensor. + out (Optional[torch.Tensor], optional): _description_. Defaults to None. + + Returns: + Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors + used by the Float8Tensor that was being computed after allgather. + """ + + (data,) = all_gather_outputs + (fp8_scale_inv, fp8_dtype, quantizer) = metadata + orig_shape = data.size() + # Quantizer has only columnwise usage set for backward pass + # In Blackwell+ architectures, transpose is not needed at all, + # even if columnwise usage is set. and is going to be handled + # internally in the update_usage method. + if out is not None: + out._data = data + else: + fp8_args = { + "shape": orig_shape, + "dtype": param_dtype, + "fp8_scale_inv": fp8_scale_inv, + "fp8_dtype": fp8_dtype, + "quantizer": quantizer, + "requires_grad": False, + "data": data, + } + out = Float8Tensor(**fp8_args) + + out.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) + return out, all_gather_outputs + @classmethod def _make_in_reduce_ex( cls, @@ -752,6 +948,9 @@ def forward( out_transpose_shape = out_transpose.size() if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: out_transpose = None + else: + view_shape_for_transpose = [shape[-1]] + list(shape[:-1]) + out_transpose = out_transpose.view(*view_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, @@ -796,6 +995,9 @@ def forward( out_transpose_shape = out_transpose.size() if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: out_transpose = None + else: + reshape_shape_for_transpose = [shape[-1]] + list(shape[:-1]) + out_transpose = out_transpose.reshape(*reshape_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 5ef5708fd..d981f7157 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -6,16 +6,17 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, Any +import warnings import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple - from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc @@ -298,7 +299,6 @@ def contiguous( memory_format: torch.memory_format = torch.contiguous_format, ) -> MXFP8Tensor: """Returns tensor with data in provided memory format - Returns `self` if data is already in correct memory format. """ @@ -314,7 +314,6 @@ def contiguous( @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # View op if func == aten.view.default: tensor = args[0] @@ -338,9 +337,335 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, ) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. + # If not, default to base class behavior. + rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None + columnwise_matches = src._columnwise_data is not None or dst._columnwise_data is None + if ( + isinstance(src, MXFP8Tensor) + and isinstance(dst, MXFP8Tensor) + and rowwise_matches + and columnwise_matches + ): + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data.detach()) + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) + if dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data.detach()) + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach()) + return dst + + # FSDP2 related functions. + if func == aten.split.Tensor: + # This is called if entire model is initialized on CUDA device and + # then splitted. Finally the shard needed by the process is used + # and other splitted shards are discarded. + if "dim" in kwargs: + dim_to_split = kwargs["dim"] + else: + dim_to_split = args[2] if len(args) > 2 else 0 + tensor = args[0] + split_size = args[1] + dim0_size = tensor.size(0) + dimlast_size = math.prod(tensor.shape[1:]) + if ( + dim0_size % split_size != 0 + or dim_to_split != 0 + or split_size % MXFP8_BLOCK_SCALING_SIZE != 0 + or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + # Handle splitting by dequantizing and splitting the hp tensor + return super().__torch_dispatch__(func, types, args, kwargs) + + out_data = [] + for data in [tensor._rowwise_data, tensor._columnwise_data]: + func_out = ( + data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + if data is not None + else None + ) + out_data.append(func_out) + + scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv] + split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE] + # Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4 + padding_multiples = [128, 4] + for scale_inv, scale_split_size, pad_multiple in zip( + scale_invs, split_sizes_for_scale, padding_multiples + ): + scale_inv_out = ( + scale_inv.__torch_dispatch__( + func, + types, + [scale_inv, scale_split_size] + list(args[2:]), + kwargs, + ) + if scale_inv is not None + else None + ) + # Pad scale_inv_out to be a multiple of pad_multiple + if scale_inv_out is not None: + current_shape = scale_inv_out.shape + pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple + if pad_dim0 > 0: + scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0)) + + out_data.append(scale_inv_out) + return [ + MXFP8Tensor( + shape=( + splitted_tensor_data[0].size() + if splitted_tensor_data[0] is not None + else splitted_tensor_data[1].size() + ), + dtype=tensor.dtype, + rowwise_data=splitted_tensor_data[0], + rowwise_scale_inv=splitted_tensor_data[2], + columnwise_data=splitted_tensor_data[1], + columnwise_scale_inv=splitted_tensor_data[3], + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + for splitted_tensor_data in zip(*out_data) + ] + if func == torch.ops.aten.as_strided.default: + # Applied on unsharded param in FSDP2. In our case, this should be a no-op + # This is needed for the case where some MXFP8 shards need padding i.e dimension 0 + # of the unsharded param is not a multiple of the world size. If that is the case, + # we down the dequantization route and weights are allgathered in high precision. + # If weight doesnt need padding, this is just a no-op. + shape = args[1] + strides = args[2] + tensor = args[0] + if ( + len(shape) != 2 + or len(strides) != 2 + or strides[1] != 1 + or shape[0] != tensor.shape[0] + or shape[1] != tensor.shape[1] + ): + return super().__torch_dispatch__(func, types, args, kwargs) + + return MXFP8Tensor.make_like(tensor) + + if func == aten.slice.Tensor: + # FSDP2 needed function. + # We need slicing for the case where some MXFP8 weight shards need padding i.e dimension 0 + # of the unsharded param is not a multiple of the world size. If that is the case, + # we down the dequantization route and weights are allgathered in high precision instead. + # If sharded weight doesnt have padding, this is just a no-op. + dim = args[1] + start = args[2] + length = args[3] + tensor = args[0] + if ( + dim != 0 + or length != tensor.shape[0] + or start != 0 + or length % MXFP8_BLOCK_SCALING_SIZE != 0 + or start % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + return super().__torch_dispatch__(func, types, args, kwargs) + return MXFP8Tensor.make_like(tensor) + + if func == aten.new_zeros.default: + rowwise_data = None + columnwise_data = None + rowwise_scale_inv = None + columnwise_scale_inv = None + tensor = args[0] + shape = args[1] + first_dim = math.prod(shape[:-1]) + last_dim = shape[-1] + if ( + first_dim % MXFP8_BLOCK_SCALING_SIZE != 0 + or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + return super().__torch_dispatch__(func, types, args, kwargs) + rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE] + columnwise_scale_inv_shape = [ + first_dim // MXFP8_BLOCK_SCALING_SIZE, + last_dim, + ] + if tensor._rowwise_data is not None: + rowwise_data = tensor._rowwise_data.__torch_dispatch__( + func, + types, + [tensor._rowwise_data] + list(args[1:]), + kwargs, + ) + rowwise_scale_inv = tensor._rowwise_scale_inv.__torch_dispatch__( + func, + types, + [tensor._rowwise_scale_inv, rowwise_scale_inv_shape] + list(args[2:]), + kwargs, + ) + if tensor._columnwise_data is not None: + columnwise_data = tensor._columnwise_data.__torch_dispatch__( + func, + types, + [tensor._columnwise_data] + list(args[1:]), + kwargs, + ) + columnwise_scale_inv = tensor._columnwise_scale_inv.__torch_dispatch__( + func, + types, + [tensor._columnwise_scale_inv, columnwise_scale_inv_shape] + list(args[2:]), + kwargs, + ) + return MXFP8Tensor( + shape=args[1], + dtype=tensor.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=tensor._quantizer.copy(), + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) # Default case return super().__torch_dispatch__(func, types, args, kwargs) + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape) + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor + (For us same as self.stride()). + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this MXFP8 tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered. + metadata: Tuple[Any]: Metadata needed for reconstructing the + MXFP8Tensor after all-gather. + """ + # pylint: disable=unused-argument + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + fsdp_state = _get_module_fsdp_state(module) + reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + quantizer = self._quantizer.copy() + # Remove padding from scale inverses before allgather + # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] + rowwise_scale_inv = self._rowwise_scale_inv + columnwise_scale_inv = self._columnwise_scale_inv + shape = self.shape + if rowwise_scale_inv is not None: + # Remove padding from rowwise scale_inv + flattened_in_shape0 = math.prod(shape[:-1]) + if rowwise_scale_inv.size(0) != flattened_in_shape0: + rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] + + if columnwise_scale_inv is not None: + # Remove padding from columnwise scale_inv + flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE + if columnwise_scale_inv.size(0) != flattened_in_shape0: + columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0] + + sharded_tensors = (self._rowwise_data, rowwise_scale_inv) + # If weights are resharded after forward pass, then its enough to set the quantizer usages + # based on whether its forward or backward pass for the allgathered weights. + # If not resharded after forward pass, the same weights allgathered in forward + # are used again in backward. And hence if we need the columnwise data/scale_inv, + # we need to send them as well for allgather in forward pass itself. + if reshard_after_forward: + training_state = fsdp_state._fsdp_param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + # Allgather only the necessary tensors based on forward/backward pass + quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) + sharded_tensors = ( + (self._columnwise_data, columnwise_scale_inv) + if is_backward_pass + else sharded_tensors + ) + else: + if quantizer.columnwise_usage: + # If weights are not resharded after forward, then both + # rowwise and columnwise data/scale_inv need to be allgathered. + sharded_tensors += (self._columnwise_data, columnwise_scale_inv) + metadata = (self._fp8_dtype, quantizer) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[MXFP8Tensor] = None, + ): + """Functions FSDP2 calls after all-gather of the + weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the MXFP8Tensor. + param_dtype (torch.dtype): high precision dtype of the MXFP8Tensor. + out (Optional[torch.Tensor], optional): _description_. Defaults to None. + Returns: + Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors + used by the MXFP8Tensor that was being computed after allgather. + """ + fp8_dtype, quantizer = metadata + rowwise_data, rowwise_scale_inv = ( + all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None) + ) + columnwise_data, columnwise_scale_inv = ( + all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None) + ) + + # Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise + if rowwise_scale_inv is not None: + # Pad rowwise_scale_inv to be a multiple of [128, 4] + current_shape = rowwise_scale_inv.shape + pad_dim0 = (128 - current_shape[0] % 128) % 128 + if pad_dim0 > 0: + rowwise_scale_inv = torch.nn.functional.pad(rowwise_scale_inv, (0, 0, 0, pad_dim0)) + + if columnwise_scale_inv is not None: + # Pad columnwise_scale_inv to be a multiple of [4, 128] + current_shape = columnwise_scale_inv.shape + pad_dim0 = (4 - current_shape[0] % 4) % 4 + if pad_dim0 > 0: + columnwise_scale_inv = torch.nn.functional.pad( + columnwise_scale_inv, (0, 0, 0, pad_dim0) + ) + + if out is not None: + out._rowwise_data = rowwise_data + out._rowwise_scale_inv = rowwise_scale_inv + out._columnwise_data = columnwise_data + out._columnwise_scale_inv = columnwise_scale_inv + out._quantizer = quantizer + else: + out = MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=fp8_dtype, + dtype=param_dtype, + shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + quantizer=quantizer, + ) + + return out, all_gather_outputs + @classmethod def _make_in_reduce_ex( cls, @@ -478,10 +803,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != ctx.shape[-1]: - raise RuntimeError( - "MXFP8Tensor does not support reshaping inner dimension " + warnings.warn( + "MXFP8Tensor does not support reshaping inner dimension. " f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + "If you are using this for FSDP2 without compiled_autograd_enabled," + "then ignore this warning. Since this view is not going to be used anywhere. ", + stacklevel=2, ) + return tensor.dequantize().view(*shape) # Construct new tensor if shape is provided new_rowwise_data = None From f8693d2b044ab83624e020ffa69f3412a916993d Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Wed, 12 Nov 2025 07:38:41 -0800 Subject: [PATCH 115/141] Fix CI failure related to bug in MXFP8 copy implementation (#2369) * fix ci issue Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert back testing changes Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Varun Thumbe Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/tensor/mxfp8_tensor.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index d981f7157..15e0b86c9 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -339,23 +339,21 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] - # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. - # If not, default to base class behavior. - rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None - columnwise_matches = src._columnwise_data is not None or dst._columnwise_data is None - if ( - isinstance(src, MXFP8Tensor) - and isinstance(dst, MXFP8Tensor) - and rowwise_matches - and columnwise_matches - ): - if dst._rowwise_data is not None: - dst._rowwise_data.copy_(src._rowwise_data.detach()) - dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) - if dst._columnwise_data is not None: - dst._columnwise_data.copy_(src._columnwise_data.detach()) - dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach()) - return dst + if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor): + # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. + # If not, default to base class behavior. + rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None + columnwise_matches = ( + src._columnwise_data is not None or dst._columnwise_data is None + ) + if rowwise_matches and columnwise_matches: + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data.detach()) + dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) + if dst._columnwise_data is not None: + dst._columnwise_data.copy_(src._columnwise_data.detach()) + dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach()) + return dst # FSDP2 related functions. if func == aten.split.Tensor: From e4bfa628632e15ef8bc1fae9b2e89686f6a097ea Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 12 Nov 2025 23:47:41 +0530 Subject: [PATCH 116/141] [Feature] Enable rope application with offsets for training (#2188) * enable applying rope offsets in backwared Signed-off-by: Sudhakar Singh * add tests for rope offsets for thd/bshd/sbhd formats Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * minor fixes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Sudhakar Singh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fused_rope.py | 170 ++++++++++++++---- .../common/fused_rope/fused_rope.cu | 85 ++++----- .../include/transformer_engine/fused_rope.h | 13 +- transformer_engine/pytorch/attention/rope.py | 85 ++++----- transformer_engine/pytorch/csrc/extensions.h | 1 + .../pytorch/csrc/extensions/apply_rope.cpp | 17 +- 6 files changed, 236 insertions(+), 135 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index aaf2eca2d..9e4ddbdad 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -58,10 +58,6 @@ def test_fused_rope( # are with the maximum length of the rope embeddings. pytest.skip("Skipping test with margin=0 and start_positions=True") - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. - pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 t = torch.rand( @@ -102,11 +98,8 @@ def test_fused_rope( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -121,17 +114,12 @@ def test_fused_rope( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) - - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - + torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() @@ -156,10 +144,6 @@ def test_fused_rope_thd( margin: int, ) -> None: - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. - pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] @@ -214,10 +198,8 @@ def test_fused_rope_thd( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -233,18 +215,142 @@ def test_fused_rope_thd( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) + assert output_fused.is_contiguous() - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - assert output_fused.is_contiguous() +@pytest.mark.parametrize("start_positions", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("rotary_percent", [1.0]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [2]) +@pytest.mark.parametrize("interleaved", [False, True]) +def test_unfused_rope_thd_vs_bshd( + dtype: torch.dtype, + hidden_size: int, + rotary_percent: float, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + """ + This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD + formats are the same. + """ + device = torch.device("cuda:0") + seqlen, max_seqlen = 16, 2048 + batch_size, head_num = 4, 256 + + # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and + # that causes unexpected issues. + seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32) + + cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to( + device=device, dtype=torch.int32 + ) + + # Create a tensor in THD format + thd = torch.rand( + (cu_seqlens[-1] // cp_size, head_num, hidden_size), + dtype=dtype, + device=device, + ) + thd.requires_grad = True + + # Clone the tensor to create a tensor in BSHD format + bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach() + bshd = bshd.to(dtype=dtype, device=device) + bshd.requires_grad = True + + # Clone the tensor to create a tensor in SBHD format + sbhd = bshd.transpose(1, 0).clone().detach() + sbhd = sbhd.to(dtype=dtype, device=device) + sbhd.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb = rotary_pos_emb(max_seqlen) + assert emb.is_contiguous() + + start_positions = cu_seqlens[:-1] if start_positions else None + + for cp_rank in range(cp_size): + # unfused bshd + output_unfused_bshd = apply_rotary_pos_emb( + bshd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="bshd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + loss_unfused_bshd = loss_func(output_unfused_bshd) + loss_unfused_bshd.backward() + grad_unfused_bshd = bshd.grad.detach().clone() + bshd.grad = None + + # unfused sbhd + output_unfused_sbhd = apply_rotary_pos_emb( + sbhd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="sbhd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_sbhd = loss_func(output_unfused_sbhd) + loss_unfused_sbhd.backward() + grad_unfused_sbhd = sbhd.grad.detach().clone() + sbhd.grad = None + + # unfused thd + output_unfused_thd = apply_rotary_pos_emb( + thd.float(), + emb, + start_positions=start_positions, + tensor_format="thd", + interleaved=interleaved, + fused=False, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_thd = loss_func(output_unfused_thd) + loss_unfused_thd.backward() + grad_unfused_thd = thd.grad.detach().clone() + thd.grad = None + + torch.testing.assert_close( + output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd + ) + torch.testing.assert_close( + output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape), + output_unfused_thd, + ) + torch.testing.assert_close( + grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + torch.testing.assert_close( + grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + + assert output_unfused_thd.is_contiguous() + assert output_unfused_bshd.is_contiguous() + assert output_unfused_sbhd.is_contiguous() @pytest.mark.parametrize("start_positions", [True, False]) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index ccd0bc44c..597a5d3c2 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -155,18 +155,18 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; - s_id_for_freqs = s_id + begin_offset; } fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -175,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq template __global__ void fused_rope_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, - const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, - const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, - const int o_stride_d) { + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, + scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s, + const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block, offset_block_dst; int cur_seqlens; @@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel( cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - s_id_for_freqs = s_id; } fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c template void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const float *freqs, const int *start_positions, + scalar_t *input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); @@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, - stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, - o_stride_d); + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, + s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, + o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *input_grads, const NVTE_QKV_Format qkv_format, - const bool interleaved, const int cp_size, const int cp_rank, const int s, - const int b, const int h, const int d, const int d2, - const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, cudaStream_t stream) { + const Tensor &start_positions, Tensor *input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);); @@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens), - *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads), - qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream); + *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size, + cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 610868f93..19047f463 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] qkv_format QKV format. * \param[in] interleaved Whether to use interleaved rotary position embedding. @@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream); + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream); /*! \brief Apply rotary positional embedding to the combined QKV input tensor. * diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index cc23d65a3..0e1222c22 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -149,7 +149,7 @@ def forward( cp_size, cp_rank, ) - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, start_positions) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank @@ -160,10 +160,11 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """Fused RoPE backward.""" - freqs, cu_seqlens = ctx.saved_tensors + freqs, cu_seqlens, start_positions = ctx.saved_tensors grad_input = tex.fused_rope_backward( grad_output, freqs, + start_positions, QKVFormat[ctx.tensor_format], ctx.interleaved, cu_seqlens, @@ -171,7 +172,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.cp_rank, ) - return grad_input, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None class FusedQKVRoPEFunc(torch.autograd.Function): @@ -278,7 +279,6 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: def _apply_rotary_pos_emb_base( t: torch.Tensor, freqs: torch.Tensor, - start_positions: torch.Tensor = None, tensor_format: str = "sbhd", interleaved: bool = False, ) -> torch.Tensor: @@ -291,45 +291,19 @@ def _apply_rotary_pos_emb_base( Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional embedding will be applied. freqs: torch.Tensor - Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', - with `s2 >= s` and `d2 <= d`. - start_positions: torch.Tensor, default = None. - Tokens in a sequence `i` should be applied with position encoding offset by - `start_positions[i]`. If `start_positions=None`, there's no offset. + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]` + and dtype 'float', with `s2 >= s` and `d2 <= d`. tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape `[seq, bs, ...]`. interleaved: bool, default = False Whether to use interleaved rotary position embedding. """ - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # In case `start_positions` are provided, create a staggered `freqs` tensor - # offset by the values in `start_positions`. - # `start_positions` is only supported for `cp_size=1` and inference. - if start_positions is not None: - max_offset = torch.max(start_positions) - assert ( - max_offset + cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!" - - # Stack staggered rope embeddings along the batch dimension - freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) - - # Note that from this point, `freqs` has a shape `(s,b,1,d)`. - - # Only apply the rotary embeddings up to the sequence length of the running - # input. - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" - freqs = freqs[:cur_seq_len] - # [seq, 1, 1, dim] -> [1, seq, 1, dim] or # [seq, b, 1, dim] -> [b, seq, 1, dim] if tensor_format == "bshd": freqs = freqs.transpose(0, 1) + # cos/sin first then dtype conversion for better precision cos_ = torch.cos(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype) @@ -366,7 +340,7 @@ def _get_freqs_on_this_cp_rank( ) # cp_size == 1 - return freqs + return freqs[:seqlen] def apply_rotary_pos_emb( @@ -388,13 +362,13 @@ def apply_rotary_pos_emb( Training: qkv_formats: "thd", "bshd", "sbhd" context parallel: yes - start_positions: no + start_positions: yes interleaving: yes Inference: qkv_formats: "thd", "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- @@ -423,22 +397,17 @@ def apply_rotary_pos_emb( cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ - - # `start_positions` is only supported for `cp_size=1` and inference. - assert not ( - cp_size > 1 and start_positions is not None - ), """start_positions != None with CP SIZE > 1 is not supported!""" - assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." + # Fused apply rope logic for THD/BSHD/SBHD formats if fused: return FusedRoPEFunc.apply( t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank ) - # Unfused THD format + # Unfused apply rope logic for THD format if tensor_format == "thd": cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -447,15 +416,18 @@ def apply_rotary_pos_emb( # `s1hd` tensors (for each sequence) and applies rotary embedding to # those sequences individually. # Note that if `start_positions` is not `None`, then for each sequence, - # it's corresponding rope offset is also supplied from `start_positions` - # individually. + # the freqs supplied are offset by the corresponding `start_positions` value. return torch.cat( [ _apply_rotary_pos_emb_base( x.unsqueeze(1), - _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), - start_positions=( - start_positions[idx : idx + 1] if start_positions is not None else None + _get_freqs_on_this_cp_rank( + ( + freqs[start_positions[idx] :] if start_positions is not None else freqs + ), # offset the freqs + x.size(0), + cp_size, + cp_rank, ), interleaved=interleaved, ) @@ -463,17 +435,28 @@ def apply_rotary_pos_emb( ] ).squeeze(1) - # Unfused SBHD/BSHD format + # Unfused apply rope logic for SBHD/BSHD format follows ... + if tensor_format == "sbhd": seqlen = t.size(0) elif tensor_format == "bshd": seqlen = t.size(1) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + + if start_positions is not None: + max_offset = torch.max(start_positions) + assert ( + max_offset + seqlen * cp_size <= freqs.shape[0] + ), f"Rotary Embeddings only suppported up to {freqs.shape[0]} sequence length!" + + # Stack staggered rope embeddings along the batch dimension + freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1) + # Note that from this point, `freqs` has a shape `(s,b,1,d)`. + return _apply_rotary_pos_emb_base( t, _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), - start_positions, tensor_format, interleaved=interleaved, ) @@ -505,7 +488,7 @@ def apply_fused_qkv_rotary_pos_emb( qkv_formats: "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 43eab9654..77fb34858 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -346,6 +346,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank); diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 064da8a67..d1dcf68c3 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -163,6 +163,7 @@ std::tuple fused_qkv_rope_forward( } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank) { @@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); + } + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); @@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, - max_s, b, h, d, d2, stride_t, + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; @@ -246,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, - h, d, d2, stride_s, stride_b, stride_h, stride_d, - at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; } From c544ced2ea3c06950be9c33ad1802c831e44ff58 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 12 Nov 2025 18:55:14 -0500 Subject: [PATCH 117/141] [JAX] Relax tolerance for the test_multiprocessing_encoder.py with NVFP4 by 0.001 (#2375) relax tol Signed-off-by: Phuong Nguyen --- examples/jax/encoder/test_multiprocessing_encoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 4d2141116..f3092278e 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -672,7 +672,7 @@ def test_te_mxfp8(self): def test_te_nvfp4(self): """Test Transformer Engine with NVFP4""" result = self.exec(True, "NVFP4BlockScaling") - assert result[0] < 0.451 and result[1] > 0.788 + assert result[0] < 0.451 and result[1] > 0.787 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): @@ -710,7 +710,7 @@ def test_te_mxfp8_shardy(self): def test_te_nvfp4_shardy(self): """Test Transformer Engine with NVFP4""" result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) - assert result[0] < 0.451 and result[1] > 0.788 + assert result[0] < 0.451 and result[1] > 0.787 if __name__ == "__main__": From d8f1e68f7c414f3e7985a8b41de4443b2f819af3 Mon Sep 17 00:00:00 2001 From: Lifu Zhang Date: Wed, 12 Nov 2025 16:34:25 -0800 Subject: [PATCH 118/141] fix gradient accumulation fusion for FSDP (#2371) Signed-off-by: Lifu Zhang Co-authored-by: Lifu Zhang Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/module/grouped_linear.py | 6 +++--- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++-- transformer_engine/pytorch/module/linear.py | 4 ++-- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 59dc2b299..f336a743d 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -293,9 +293,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], origin_weights[i] = ctx.weight_objects[i] ctx.weight_objects[i] = None - if ctx.fuse_wgrad_accumulation: - for i in range(N): - origin_weights[i].main_grad = main_grads[i] + if ctx.fuse_wgrad_accumulation: + for i in range(N): + origin_weights[i].main_grad = main_grads[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index abe8c5829..20a67cba4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -572,8 +572,8 @@ def backward( if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: origin_weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - origin_weight.main_grad = main_grad + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + origin_weight.main_grad = main_grad # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 0e2310a5a..46b9dbd85 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -508,8 +508,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - weight.main_grad = main_grad + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already From d0d4063130e3ae8b40e557919eb04bc76b721c0c Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 13 Nov 2025 14:28:09 +0100 Subject: [PATCH 119/141] [PyTorch] Fix amax computation using output_t data in normalization (#2355) Fix amax computation using output_t data in normalization Signed-off-by: Evgeny --- tests/cpp/operator/test_normalization.h | 12 +++++++++++- .../normalization/layernorm/ln_fwd_kernels.cuh | 18 ++++++++++++++++-- .../rmsnorm/rmsnorm_fwd_kernels.cuh | 18 ++++++++++++++++-- 3 files changed, 43 insertions(+), 5 deletions(-) diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index fe69852d0..271345686 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -114,8 +114,18 @@ void compute_ref_output(NormType norm_type, tmp = current * rsigma[i] * g; } + // Write output (scaled only for fp8 paths) output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); + + // amax semantics: + // - fp8_out (scale != 1): amax on pre-scale compute value 'tmp' + // - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16) + if (scale != 1.f) { + current_max = fmaxf(current_max, fabsf(tmp)); + } else { + OutputType out_t_val = static_cast(tmp); + current_max = fmaxf(current_max, fabsf(static_cast(out_t_val))); + } } } diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 6050b164d..38c409607 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -123,7 +123,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( if (requires_amax) { __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(temp_output)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(temp_output); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } } if (params.fp8_out) { temp_output = temp_output * scale; @@ -290,7 +297,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(z_ij)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(z_ij); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } if (params.fp8_out) { z.data.elt[jt] = z_ij * scale; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index fc093b73a..7fed7f123 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -115,7 +115,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke if (requires_amax) { __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(temp_output)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(temp_output); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } } if (params.fp8_out) { temp_output = temp_output * scale; @@ -265,7 +272,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(z_ij)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(z_ij); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } if (params.fp8_out) { z.data.elt[jt] = z_ij * scale; } From ef28c86582b5539655e7286366c699bec99284bc Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 13 Nov 2025 12:30:59 -0500 Subject: [PATCH 120/141] [JAX] NVFP4 scale swizzling via nvte kernel (#2350) * swizzle via nvte Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 5 ++ .../jax/csrc/extensions/gemm.cpp | 62 ++++++++++++++----- 2 files changed, 50 insertions(+), 17 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 72bee251c..9ffec2c6a 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -533,6 +533,9 @@ def _dims_are_consecutive(dims): # Declare cuBLAS workspace workspace_size = get_cublas_workspace_size_bytes() + # NVFP4 swizzling happen in via nvte kernel instead of JAX transposes + if scaling_mode.is_nvfp4_scaling: + workspace_size += lhs_scale_inv.size + rhs_scale_inv.size if not collective_op.is_none: workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not @@ -662,6 +665,8 @@ def impl( rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis ) + # Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel + if scaling_mode.is_mxfp8_scaling: lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 8a3658a0b..6566ff168 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -34,8 +34,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { } std::tuple> xla_buffer_to_nvte_gemm_operand( - cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, - size_t axis_boundary, bool rowwise) { + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, uint8_t *swizzle_scale_ptr, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { // Set tensor data with collapsed 2D shape auto buffer_dims = buffer.dimensions(); std::vector input_shape = {product(buffer_dims, 0, axis_boundary), @@ -56,17 +56,32 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); std::vector scale_shape = {1}; - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + auto is_nvfp4 = is_nvfp4_scaling(scaling_mode); + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) { // Block scaling also needs to be collapsed to match 2D data scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + NVTE_CHECK(typeToSize(scale_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); } - - auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - if (rowwise) { + if (!is_nvfp4) { + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + } else { // Swizzle for NVFP4 + NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + // Create tensor to hold swizzled scale factor + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + // Set swizzled scales into the input tensor + input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); } } @@ -145,16 +160,34 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator, JAXX_Collective_Op collective_op) { + // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales) + uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr; + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + size_t workspace_size = static_cast(workspace->element_count()) - 256; + if (is_nvfp4_scaling(scaling_mode)) { + auto lhs_scale_size = product(lhs_scale_inv.dimensions()); + auto rhs_scale_size = product(rhs_scale_inv.dimensions()); + workspace_size = workspace_size - lhs_scale_size - rhs_scale_size; + lhs_swizzle_scale_ptr = workspace_ptr; + rhs_swizzle_scale_ptr = workspace_ptr + lhs_scale_size; + workspace_ptr = rhs_swizzle_scale_ptr + rhs_scale_size; + } + auto workspace_ = TensorWrapper(workspace_ptr, std::vector{workspace_size}, DType::kByte); + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? true : rhs_transposed; - auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode, - lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, - rhs_axis_boundary, make_rhs_rowwise); + + auto [lhs_, lhs_shape] = + xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr, + scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = + xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr, + scaling_mode, rhs_axis_boundary, make_rhs_rowwise); std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; @@ -191,11 +224,6 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); - // cuBLAS workspace + 256 alignment enforcement - auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); - workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); - std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; - auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); float one = 1.; From 9440b76aa8d4333d1347c44f415c5789f1e038be Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 13 Nov 2025 14:05:18 -0500 Subject: [PATCH 121/141] [JAX] Shardy rule + QuantizeLayout Rework (#2364) * shardy + quantize_layout rework Signed-off-by: Phuong Nguyen * add assertion for NVFP4 in fused act and fused norm primitive Signed-off-by: Phuong Nguyen * add assertions Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 170 +++++++++--------- transformer_engine/jax/cpp_extensions/misc.py | 8 +- .../jax/cpp_extensions/normalization.py | 109 ++++++----- .../jax/cpp_extensions/quantization.py | 90 +++++----- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/activation.cpp | 45 ++--- transformer_engine/jax/csrc/extensions/misc.h | 14 +- .../jax/csrc/extensions/normalization.cpp | 29 ++- .../jax/csrc/extensions/pybind.cpp | 9 +- .../jax/csrc/extensions/quantization.cpp | 40 ++--- transformer_engine/jax/quantize/__init__.py | 1 + transformer_engine/jax/quantize/misc.py | 61 +++++++ transformer_engine/jax/quantize/quantizer.py | 55 ++---- .../jax/quantize/scaling_modes.py | 130 ++++++++++---- transformer_engine/jax/quantize/tensor.py | 36 ++-- 15 files changed, 456 insertions(+), 347 deletions(-) create mode 100644 transformer_engine/jax/quantize/misc.py diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index bb3c56bcf..aa84fafd3 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -10,7 +10,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.sharding import PartitionSpec import numpy as np @@ -159,7 +159,7 @@ class ActLuPrimitive(BasePrimitive): 11, 12, 13, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + ) # out_dtype, act_enum, act_len, scaling_mode, quantize_layout, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer inner_primitive = None outer_primitive = None @@ -173,7 +173,7 @@ def abstract( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -201,6 +201,13 @@ def abstract( "Current tensor scaling is not yet supported for fused activation and quantization." " Please do activation in higher-precision then quantize with current tensor scaling." ) + assert not ScalingMode(scaling_mode).is_nvfp4_scaling, ( + "NVFP4 block scaling is not yet supported for fused activation and quantization." + " Please do activation in higher-precision then quantize with current tensor scaling." + ) + assert ( + not quantize_layout.is_colwise_only + ), "Fused activation with colwise-only quantization is not supported." out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) @@ -210,7 +217,7 @@ def abstract( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1) - if not is_2x: + if quantize_layout.is_rowwise_only: out_shape = (1,) colwise_scale_inv_shape = (1,) colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) @@ -232,7 +239,7 @@ def lowering( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -259,7 +266,7 @@ def lowering( amax, act_enum=act_enum, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, act_params=act_params.to_ffi_lowering_dict(), output_amax_when_no_scaling=output_amax_when_no_scaling, ) @@ -274,7 +281,7 @@ def impl( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -297,7 +304,7 @@ def impl( act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, act_params=act_params, amax_scope=amax_scope, @@ -313,7 +320,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -329,7 +336,7 @@ def batcher( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -356,7 +363,7 @@ def batcher( act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, act_params=act_params, amax_scope=amax_scope, @@ -373,7 +380,7 @@ def infer_sharding_from_operands( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -402,7 +409,7 @@ def infer_sharding_from_operands( out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: @@ -419,7 +426,7 @@ def infer_sharding_from_operands( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -444,7 +451,7 @@ def partition( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -462,7 +469,7 @@ def partition( out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: @@ -479,7 +486,10 @@ def partition( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: + assert not ScalingMode( + scaling_mode + ).is_colwise_transposed, "Transpose layout scaling modes are not supported here yet" colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -514,7 +524,7 @@ def sharded_impl(x, scale, amax): act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, act_params=act_params, amax_scope=amax_scope, @@ -550,7 +560,7 @@ def shardy_sharding_rule( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, act_params, amax_scope, @@ -574,37 +584,28 @@ def shardy_sharding_rule( mesh, result_types, ) - prefix = "ActLu_" + prefix = "ActLu" input_shape = value_types[0].shape output_shape = input_shape[:-2] + input_shape[-1:] # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - output_shape, unique_var=prefix + "x", flatten_axis=-1 + output_shape, unique_var=prefix, flatten_axis=-1, q_layout=quantize_layout ) - x_axes = scale_rules.input_spec - # Correct input spec with act dim - x_axes = x_axes[:-1] + (prefix + "_act_dim",) + x_axes[-1:] - out = scale_rules.input_spec - - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "scale_inv_colwise",) - if is_2x: - colwise_scale_inv = scale_rules.colwise_rule - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = multidim_transpose(out, transpose_axis=-1) - else: - colwise_out = out - colwise_scale_inv = scale_rules.colwise_rule - - amax = (prefix + "amax",) + # Correct the input spec with act dim + input_spec = scale_rules.input_spec + input_spec = input_spec[:-1] + (prefix + "_act_dim",) + input_spec[-1:] + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) return SdyShardingRule( + (tuple(input_spec), scale, amax), ( - x_axes, - ("…1",), + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, amax, ), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax), **scale_rules.factor_sizes, ) @@ -612,7 +613,6 @@ def shardy_sharding_rule( register_primitive(ActLuPrimitive) -# TODO(Jeremy): replace is_2x with q_layout class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive @@ -620,7 +620,7 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + # out_dtype, scaling_mode, quantize_layout, scale_dtype, is_dbias, act_enum, act_len, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -634,7 +634,7 @@ def abstract( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -678,7 +678,7 @@ def abstract( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) - if is_2x: + if quantize_layout.is_rowwise_colwise: if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: @@ -700,7 +700,7 @@ def abstract( jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), scaling_mode, - is_2x, + quantize_layout.value, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -741,7 +741,7 @@ def lowering( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -777,7 +777,7 @@ def lowering( scale, amax, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, is_dbias=is_dbias, act_enum=int(act_enum), act_params=act_params.to_ffi_lowering_dict(), @@ -792,7 +792,7 @@ def impl( amax, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -816,7 +816,7 @@ def impl( amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, @@ -835,7 +835,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -848,7 +848,7 @@ def batcher( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -883,7 +883,7 @@ def batcher( amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, @@ -901,7 +901,7 @@ def batcher( def infer_sharding_from_operands( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -928,7 +928,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: @@ -954,7 +954,7 @@ def infer_sharding_from_operands( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -981,7 +981,7 @@ def infer_sharding_from_operands( def partition( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -1003,7 +1003,7 @@ def partition( mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: @@ -1029,7 +1029,7 @@ def partition( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -1066,7 +1066,7 @@ def sharded_impl(dz, x, scale, amax): amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, @@ -1102,7 +1102,7 @@ def sharded_impl(dz, x, scale, amax): def shardy_sharding_rule( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, @@ -1132,28 +1132,30 @@ def shardy_sharding_rule( ) prefix = "DActLuDBias_" + # get sharding rules base on the input shape scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - value_types[1].shape, unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, + unique_var=prefix, + flatten_axis=-2, + q_layout=quantize_layout, ) - x_axes = scale_rules.input_spec - dz_axes = (*x_axes[:-2], x_axes[-1]) - out = x_axes - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "scale_inv_colwise",) - if is_2x: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) - else: - colwise_out = out - colwise_scale_inv = scale_rules.colwise_rule - - dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) - amax = (prefix + "amax",) + input_spec = scale_rules.input_spec + dz_spec = (*input_spec[:-2], input_spec[-1]) + dbias = input_spec[-2:] if is_dbias else (prefix + "_dbias",) + amax = (prefix + "_amax",) + scale = (prefix + "_scale",) return SdyShardingRule( - (dz_axes, x_axes, ("…2",), amax), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + (tuple(dz_spec), tuple(input_spec), scale, amax), + ( + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, + dbias, + ), **scale_rules.factor_sizes, ) @@ -1269,7 +1271,7 @@ def act_lu( return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( @@ -1298,7 +1300,7 @@ def act_lu( act_enum=act_type_id, act_len=act_len, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, act_params=act_params, amax_scope=amax_scope, @@ -1354,7 +1356,7 @@ def act_lu( act_enum=act_type_id, act_len=act_len, scaling_mode=quantizer.scaling_mode.value, - is_2x=quantizer.is_2x2x(), + quantize_layout=quantizer.q_layout, scale_dtype=quantizer.get_scale_dtype(), act_params=act_params, amax_scope=amax_scope, @@ -1415,7 +1417,7 @@ def quantize_dact_dbias( act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive if not PrimitiveClass.enabled() or ( - quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE + quantizer is not None and quantizer.q_layout.is_colwise_only ): return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) if quantizer is None: @@ -1428,7 +1430,7 @@ def quantize_dact_dbias( out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused + quantize_layout=QuantizeLayout.ROWWISE, # unused scale_dtype=jnp.float32, # unused is_dbias=False, act_enum=act_type_id, @@ -1555,7 +1557,7 @@ def quantize_dact_dbias( amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=quantizer.is_2x2x(), + quantize_layout=quantizer.q_layout, scale_dtype=quantizer.get_scale_dtype(), is_dbias=is_dbias, act_enum=act_type_id, @@ -1568,7 +1570,7 @@ def quantize_dact_dbias( ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise: colwise_scale_inv = rowwise_scale_inv quantizer.update(updated_amax) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 572d82f18..f15fe72ba 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -207,7 +207,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant break # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, # but this fails when bias fusion is turned on with arch < 100. - force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() + force_1x_quantization = ( + quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise + ) return ( (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 @@ -229,7 +231,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, @return: the output of 'f' with the colwise output calculated """ should_apply_war = ( - quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() + quantizer is not None + and quantizer.scaling_mode.is_tensor_scaling() + and quantizer.q_layout.is_rowwise_colwise ) if not should_apply_war: return None diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index d09ce7ef7..92efb91a7 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -11,7 +11,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec @@ -112,7 +112,7 @@ def abstract( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -148,6 +148,13 @@ def abstract( "Current tensor scaling is not supported for fused norm and quantization. Please do" " norm in higher-precision then quantize with current tensor scaling." ) + assert not ScalingMode(scaling_mode).is_nvfp4_scaling, ( + "NVFP4 block scaling is not yet supported for fused norm and quantization." + " Please do norm in higher-precision then quantize with current tensor scaling." + ) + assert ( + not quantize_layout.is_colwise_only + ), "Fused norm with colwise-only quantization is not supported." mu_rsigama_dtype = jnp.float32 @@ -165,7 +172,7 @@ def abstract( updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -173,7 +180,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,) colwise_scale_inv_aval = jax.core.ShapedArray( shape=colwise_scale_inv_shape, dtype=scale_dtype ) @@ -189,7 +196,7 @@ def abstract( zero_centered_gamma, epsilon, get_forward_sm_margin(), - is_2x, + True, # is_training ) wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -245,7 +252,7 @@ def lowering( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -287,7 +294,7 @@ def lowering( epsilon=epsilon, sm_margin=sm_margin, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, output_amax_when_no_scaling=output_amax_when_no_scaling, ) @@ -303,7 +310,7 @@ def impl( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -335,7 +342,7 @@ def impl( epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -349,7 +356,7 @@ def impl( scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.has_colwise: colwise_scale_inv = colwise_scale_inv.flatten()[ : reduce(operator.mul, colwise_scale_inv_shape, 1) ].reshape(colwise_scale_inv_shape) @@ -373,7 +380,7 @@ def batcher( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -409,7 +416,7 @@ def batcher( epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -426,7 +433,7 @@ def infer_sharding_from_operands( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -450,7 +457,7 @@ def infer_sharding_from_operands( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) @@ -488,7 +495,7 @@ def partition( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -524,7 +531,7 @@ def partition( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) @@ -586,7 +593,7 @@ def sharded_impl(x, scale, amax, gamma, beta): epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -623,7 +630,7 @@ def shardy_sharding_rule( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, amax_scope, transpose_batch_sequence, @@ -646,25 +653,29 @@ def shardy_sharding_rule( result_types, ) - prefix = "NormFwd_" + prefix = "NormFwd" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - value_types[0].shape, unique_var=prefix + "x", flatten_axis=-1 + value_types[0].shape, + unique_var=prefix, + flatten_axis=-1, + q_layout=quantize_layout, ) - x_axes = scale_rules.input_spec + input_spec = scale_rules.input_spec - out = x_axes - colwise_out = out if is_2x else (prefix + "out_colwise",) - rsigma = x_axes[:-1] - mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma - amax = (prefix + "amax",) + rsigma = input_spec[:-1] + mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) + gamma = (BATCHING + prefix + "_gamma",) + beta = (BATCHING + prefix + "_beta",) return SdyShardingRule( - (x_axes, ("…1",), amax, ("…2",), ("…3",)), + (input_spec, scale, amax, gamma, beta), ( - out, - colwise_out, - scale_rules.rowwise_rule, - scale_rules.colwise_rule, + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, amax, mu, rsigma, @@ -987,7 +998,7 @@ def layernorm_fwd( return (output, mu, rsigma) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -1008,7 +1019,7 @@ def layernorm_fwd( epsilon=epsilon, out_dtype=x.dtype, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, amax_scope=amax_scope, transpose_batch_sequence=False, @@ -1067,10 +1078,11 @@ def layernorm_fwd( ) return out, mu, rsigma - is_2x2x = quantizer.is_2x2x() - # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): - is_2x2x = False + # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose + q_layout = quantizer.q_layout + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -1090,7 +1102,7 @@ def layernorm_fwd( epsilon=epsilon, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=is_2x2x, + quantize_layout=q_layout, scale_dtype=quantizer.get_scale_dtype(), amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -1099,8 +1111,7 @@ def layernorm_fwd( ) quantizer.update(updated_amax) - # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1238,7 +1249,7 @@ def rmsnorm_fwd( return (output, rsigma) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -1261,7 +1272,7 @@ def rmsnorm_fwd( epsilon=epsilon, out_dtype=x.dtype, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -1321,10 +1332,11 @@ def rmsnorm_fwd( ) return out, rsigma - is_2x2x = quantizer.is_2x2x() - # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): - is_2x2x = False + # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose + q_layout = quantizer.q_layout + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -1344,7 +1356,7 @@ def rmsnorm_fwd( epsilon=epsilon, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=is_2x2x, + quantize_layout=q_layout, scale_dtype=quantizer.get_scale_dtype(), amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, @@ -1353,8 +1365,7 @@ def rmsnorm_fwd( ) quantizer.update(updated_amax) - # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 67c505bc9..a0e1a6406 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -11,7 +11,7 @@ import jax import jax.numpy as jnp from jax import dtypes, ffi -from jax.experimental.custom_partitioning import SdyShardingRule +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.sharding import PartitionSpec import transformer_engine_jax @@ -122,7 +122,7 @@ def abstract( f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" ) - if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if QuantizeLayout(q_layout).has_rowwise: rowwise_out_shape = out_shape else: rowwise_out_shape = (1,) @@ -170,7 +170,7 @@ def abstract( broadcast_2d_scale_shape_to_1d=True, ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if QuantizeLayout(q_layout).has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: @@ -194,9 +194,7 @@ def abstract( jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(scale_dtype), scaling_mode, - QuantizeLayout( - q_layout - ), # For now until we have auto-decoding for QuantizeLayout enum + q_layout.value, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -272,7 +270,7 @@ def lowering( post_rht_amax, rht_matrix, scaling_mode=scaling_mode.value, - q_layout=q_layout, + q_layout=q_layout.value.value, flatten_axis=flatten_axis, is_dbias=is_dbias, stochastic_rounding=stochastic_rounding, @@ -335,7 +333,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -424,7 +422,7 @@ def infer_sharding_from_operands( PartitionSpec(*x_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: @@ -448,7 +446,7 @@ def infer_sharding_from_operands( if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ( ScalingMode(scaling_mode).is_block_scaling and ScalingMode(scaling_mode).is_colwise_transposed @@ -505,7 +503,7 @@ def partition( desc="BaseDBiasQuantizePrimitive.out_sharding", ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: @@ -529,7 +527,7 @@ def partition( if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: if ( ScalingMode(scaling_mode).is_block_scaling and ScalingMode(scaling_mode).is_colwise_transposed @@ -643,39 +641,37 @@ def shardy_sharding_rule( result_types, ) - prefix = "DBiasQuantize_" + prefix = "DBiasQuantize" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( value_types[0].shape, - unique_var=prefix + "x", + unique_var=prefix, flatten_axis=flatten_axis, + q_layout=q_layout, broadcast_2d_scale_shape_to_1d=True, ) - x_axes = scale_rules.input_spec - - out = x_axes - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "colwise_scale_inv",) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv = scale_rules.colwise_rule - if ScalingMode(scaling_mode).is_colwise_transposed: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) - colwise_scale_inv = tuple( - multidim_transpose(colwise_scale_inv, transpose_axis=flatten_axis) - ) - else: - colwise_out = x_axes - - dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) - amax = (prefix + "amax",) - sr_rng_state = (prefix + "sr_rng_state_partition_axis", prefix + "sr_rng_state_data_axis") + input_spec = scale_rules.input_spec + dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",) + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) + sr_rng_state = ( + BATCHING + prefix + "_sr_rng_state_partition_axis", + BATCHING + prefix + "sr_rng_state_data_axis", + ) - post_rht_amax = (prefix + "post_rht_amax",) - rht_matrix = (prefix + "rht_matrix_1", prefix + "rht_matrix_2") + post_rht_amax = (BATCHING + prefix + "_post_rht_amax",) + rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2") return SdyShardingRule( - (x_axes, ("…1",), amax, sr_rng_state, post_rht_amax, rht_matrix), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix), + ( + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, + dbias, + ), **scale_rules.factor_sizes, ) @@ -762,7 +758,7 @@ def _quantize_dbias_impl( # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - is_unsupported = quantizer.q_layout == QuantizeLayout.COLWISE and not ( + is_unsupported = quantizer.q_layout.is_colwise_only and not ( quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING and hasattr(quantizer, "use_rht") and quantizer.use_rht @@ -845,7 +841,7 @@ def _quantize_dbias_impl( is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( quantizer.scaling_mode.is_tensor_scaling() - and quantizer.is_2x2x() + and quantizer.q_layout.is_rowwise_colwise and is_1x_kernel_supported ) q_layout = quantizer.q_layout @@ -879,7 +875,7 @@ def _quantize_dbias_impl( rht_matrix, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=q_layout.value, + q_layout=q_layout, flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, @@ -888,10 +884,10 @@ def _quantize_dbias_impl( use_rht=use_rht, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise: colwise_scale_inv = rowwise_scale_inv - if q_layout == QuantizeLayout.ROWWISE: + if q_layout.is_rowwise_only: # Quantizer requires 2x quantization, but we are using 1x quantization # for performance reasons, so we need to generate the colwise data in JAX if flatten_axis < 0: @@ -1043,7 +1039,7 @@ def abstract( flatten_axis=flatten_axis, ) - if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_rowwise: rowwise_out_shape = out_shape else: rowwise_out_shape = (1,) @@ -1052,7 +1048,7 @@ def abstract( amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: colwise_out_shape = out_shape else: colwise_out_shape = (1,) @@ -1117,7 +1113,7 @@ def lowering( scale, group_sizes, scaling_mode=scaling_mode.value, - q_layout=q_layout, + q_layout=q_layout.value.value, flatten_axis=flatten_axis, ) @@ -1240,7 +1236,7 @@ def grouped_quantize( ) # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet # So we performance ROWWISE_COLWISE and use the colwise_tensor_output - apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE + apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout ( rowwise_casted_output, @@ -1254,7 +1250,7 @@ def grouped_quantize( group_sizes, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=q_layout.value, + q_layout=q_layout, flatten_axis=flatten_axis, group_axis=group_axis, scale_dtype=quantizer.get_scale_dtype(), @@ -1262,7 +1258,7 @@ def grouped_quantize( # For DelayedScaling2x and CurrentScaling2x, the scale buffer # is shared between rowwise and colwise - if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war: + if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war: colwise_scale_inv = rowwise_scale_inv # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead? diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 87c6fa91c..c1c7e0d66 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -57,7 +57,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - JAXX_Scaling_Mode scaling_mode, bool is_2x); + JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout); // Normalization XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); @@ -87,7 +88,7 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, - QuantizeLayout q_layout); + JAXX_Quantize_Layout quantize_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); @@ -162,5 +163,6 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout); #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index f512321c3..34ce29ae1 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -18,7 +18,8 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int, ActivationConfig act_params, bool output_amax_when_no_scaling) { + JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params, + bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; @@ -40,7 +41,6 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis auto input_shape = std::vector{m, static_cast(act_len * n)}; @@ -77,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; @@ -158,7 +158,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("act_params") .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); @@ -167,11 +167,12 @@ Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - int64_t act_enum, JAXX_Scaling_Mode scaling_mode, bool is_2x_int, - ActivationConfig act_params, bool output_amax_when_no_scaling) { + int64_t act_enum, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params, + bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, - updated_amax_buf, act_enum, scaling_mode, is_2x_int, act_params, + updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params, output_amax_when_no_scaling); } @@ -188,13 +189,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("act_params") .Attr("output_amax_when_no_scaling")); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - JAXX_Scaling_Mode scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -226,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid std::vector{1}); } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); @@ -260,9 +262,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t act_enum, bool is_2x, - bool is_dbias, ActivationConfig act_params, - bool output_amax_when_no_scaling) { + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, + JAXX_Quantize_Layout quantize_layout, bool is_dbias, + ActivationConfig act_params, bool output_amax_when_no_scaling) { // parameters for clamped swiglu used in GPT OSS auto swiglu_limit = act_params.clamped_swiglu.limit; auto swiglu_alpha = act_params.clamped_swiglu.alpha; @@ -340,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; @@ -370,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && + is_quantize_2x2x(quantize_layout) && act_len == 2), "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { @@ -465,7 +468,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("is_dbias") .Attr("act_params") .Attr("output_amax_when_no_scaling"), @@ -476,13 +479,13 @@ Error_Type DActLuDBiasQuantizeInitializeFFI( Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias, ActivationConfig act_params, - bool output_amax_when_no_scaling) { + int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias, + ActivationConfig act_params, bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, - workspace_buf, scaling_mode, act_enum, is_2x, is_dbias, act_params, - output_amax_when_no_scaling); + workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias, + act_params, output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, @@ -502,7 +505,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("is_dbias") .Attr("act_params") .Attr("output_amax_when_no_scaling")); diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 07e9aec7e..21b50c1af 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -34,12 +34,24 @@ inline size_t product(const std::vector &shape) { return ret; } -enum class QuantizeLayout { +enum class JAXX_Quantize_Layout : int64_t { ROWWISE, COLWISE, ROWWISE_COLWISE, }; +inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + +inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + +inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + enum class JAXX_Scaling_Mode : int64_t { NO_SCALING = 0, DELAYED_TENSOR_SCALING = 1, diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index 378e009c8..b01e23c12 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -66,7 +66,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, - bool is_2x, bool output_amax_when_no_scaling) { + JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -86,7 +86,6 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto _norm_type = static_cast(norm_type); - auto _is_2x = static_cast(is_2x); auto x_size = product(x_buf.dimensions()); auto gamma_size = product(gamma_buf.dimensions()); @@ -134,7 +133,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); } - if (_is_2x) { + if (is_quantize_2x2x(quantize_layout)) { output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(), static_cast(out_dtype), input_shape); output_tensor.set_columnwise_scale_inv( @@ -185,25 +184,23 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); -Error_Type NormForwardInitializeFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type amax_buf, Buffer_Type gamma_buf, - Buffer_Type beta_buf, Result_Type output_buf, - Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - Result_Type mu_buf, Result_Type rsigma_buf, - Result_Type wkspace_buf, int norm_type, - bool zero_centered_gamma, double epsilon, int64_t sm_margin, - JAXX_Scaling_Mode scaling_mode, bool is_2x, - bool output_amax_when_no_scaling) { +Error_Type NormForwardInitializeFFI( + cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf, + Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, + bool output_amax_when_no_scaling) { return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, wkspace_buf, norm_type, zero_centered_gamma, epsilon, sm_margin, - scaling_mode, is_2x, output_amax_when_no_scaling); + scaling_mode, quantize_layout, output_amax_when_no_scaling); } XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, @@ -227,7 +224,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializ .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x") + .Attr("quantize_layout") .Attr("output_amax_when_no_scaling")); pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index d740df0e2..e57d07872 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -176,11 +176,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) .export_values(); - pybind11::enum_(m, "QuantizeLayout", - pybind11::module_local()) - .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) - .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) + pybind11::enum_(m, "JAXX_Quantize_Layout", pybind11::module_local()) + .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE) + .value("COLWISE", JAXX_Quantize_Layout::COLWISE) + .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE) .export_values(); pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index a45a69882..1f7db8438 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -20,7 +20,7 @@ namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, - QuantizeLayout q_layout) { + JAXX_Quantize_Layout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -42,7 +42,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); auto scale_shape = std::vector{1}; // Only the pointers will be checked for scale_inv, thus the shapes do not matter - if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + if (is_quantize_rowwise(q_layout)) { output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { if (is_nvfp4) @@ -52,7 +52,7 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } } - if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + if (is_quantize_colwise(q_layout)) { auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); @@ -90,8 +90,8 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis, - bool stochastic_rounding, bool use_rht) { + JAXX_Quantize_Layout quantize_layout, bool is_dbias, + int64_t flatten_axis, bool stochastic_rounding, bool use_rht) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -101,8 +101,6 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T auto *input = input_buf.untyped_data(); - auto const quantize_layout = static_cast(quantize_layout_enum); - auto *output = output_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data(); @@ -127,15 +125,13 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - bool const is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); - if (quantize_layout == QuantizeLayout::ROWWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_rowwise(quantize_layout)) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); if (is_tensor_scaling) { @@ -180,10 +176,9 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T quant_config.set_rng_state(sr_rng_state_tensor.data()); } - if (quantize_layout == QuantizeLayout::COLWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_colwise(quantize_layout)) { if (is_nvfp4 && use_rht) { - if (quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_2x2x(quantize_layout)) { // Do regular rowwise quantization without RHT nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); } @@ -281,7 +276,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") - .Attr("q_layout") + .Attr("q_layout") .Attr("is_dbias") .Attr("flatten_axis") .Attr("stochastic_rounding") @@ -323,7 +318,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, Result_Type colwise_scale_invs, Result_Type amaxs, - JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -336,7 +331,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type()); auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type()); auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type()); - auto const quantize_layout = static_cast(quantize_layout_enum); auto *input_ptr = reinterpret_cast(inputs.untyped_data()); auto *scale_ptr = reinterpret_cast(scales.untyped_data()); @@ -346,10 +340,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto *colwise_sinv_ptr = reinterpret_cast(colwise_scale_invs->untyped_data()); auto *amax_ptr = reinterpret_cast(amaxs->untyped_data()); - bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE; - bool has_colwise = quantize_layout == QuantizeLayout::COLWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE; bool is_delayed_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; @@ -359,8 +349,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype); size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype); - size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0; - size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0; + size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0; + size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0; size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0; size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0; @@ -423,7 +413,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto inp_i = TensorWrapper(static_cast(input_ptr), shape_i, in_dtype); auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - if (has_rowwise) { + if (is_quantize_rowwise(quantize_layout)) { out_i.set_rowwise_data(static_cast(output_ptr), out_dtype, shape_i); if (is_fp8_dtype(out_dtype)) { @@ -442,7 +432,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty } } - if (has_colwise) { + if (is_quantize_colwise(quantize_layout)) { auto &tmp_shape = is_tensor_scaling ? shape_trans_i : shape_i; out_i.set_columnwise_data(static_cast(colwise_output_ptr), out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling @@ -501,7 +491,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("scaling_mode") - .Attr("q_layout") + .Attr("q_layout") .Attr("flatten_axis")); } // namespace jax diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index 9616965c7..878067a78 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -17,3 +17,4 @@ from .hadamard import * from .helper import * from .device_utils import * +from .misc import * diff --git a/transformer_engine/jax/quantize/misc.py b/transformer_engine/jax/quantize/misc.py new file mode 100644 index 000000000..c1e169d00 --- /dev/null +++ b/transformer_engine/jax/quantize/misc.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +This module provides additional enum and utilities for quantizing tensors in JAX. +""" +from dataclasses import dataclass +from enum import Enum + +from transformer_engine_jax import JAXX_Quantize_Layout + +__all__ = [ + "QuantizeLayout", +] + + +@dataclass(frozen=True) +class QuantizeLayout(Enum): + "Wrapper for JAXX_Quantize_Layout" + + ROWWISE = JAXX_Quantize_Layout.ROWWISE + COLWISE = JAXX_Quantize_Layout.COLWISE + ROWWISE_COLWISE = JAXX_Quantize_Layout.ROWWISE_COLWISE + + @property + def has_rowwise(self) -> bool: + """If the layout has the rowwise component""" + return self.value in (JAXX_Quantize_Layout.ROWWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE) + + @property + def has_colwise(self) -> bool: + """If the layout has the colwise component""" + return self.value in (JAXX_Quantize_Layout.COLWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE) + + @property + def is_rowwise_colwise(self) -> bool: + """If layout is both rowwise and colwise""" + return self.value == JAXX_Quantize_Layout.ROWWISE_COLWISE + + @property + def is_rowwise_only(self) -> bool: + """If layout is rowwise only""" + return self.value == JAXX_Quantize_Layout.ROWWISE + + @property + def is_colwise_only(self) -> bool: + """If layout is colwise only""" + return self.value == JAXX_Quantize_Layout.COLWISE + + def __eq__(self, other): + """Compare this quantize layout with another. + + Args: + other: The other quantize layout to compare with + + Returns: + True if the modes are equal, False otherwise + """ + if not isinstance(other, QuantizeLayout): + return False + return self.value == other.value diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index eb2b7b592..8a54f0b1d 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -15,10 +15,10 @@ import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeLayout from transformer_engine.common import recipe from .scaling_modes import ScalingMode +from .misc import QuantizeLayout from .hadamard import apply_rht from .tensor import ( ScaledTensor, @@ -37,7 +37,6 @@ from ..sharding import get_num_devices_in_mesh __all__ = [ - "QuantizeLayout", "Quantizer", "QuantizerSet", "CurrentScaleQuantizer", @@ -118,14 +117,6 @@ def update(self, *args, **kwargs): """Update quantizer state (no-op in base class).""" del args, kwargs - def is_2x2x(self) -> bool: - """Check if quantizer uses both row-wise and column-wise quantization. - - Returns: - True if using both row-wise and column-wise quantization - """ - return self.q_layout == QuantizeLayout.ROWWISE_COLWISE - def get_data_layout(self) -> str: """Get the data data_layout string. @@ -135,11 +126,11 @@ def get_data_layout(self) -> str: Raises: ValueError: If quantization axis is invalid """ - if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: + if self.q_layout.is_rowwise_colwise: return self.data_layout - if self.q_layout == QuantizeLayout.ROWWISE: + if self.q_layout.is_rowwise_only: return self.data_layout[0] - if self.q_layout == QuantizeLayout.COLWISE: + if self.q_layout.is_colwise_only: return self.data_layout[1] raise ValueError(f"Invalid q_layout: {self.q_layout}") @@ -174,18 +165,10 @@ def quantize( """ del kwargs - is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise - if (is_rowwise and is_colwise) or self.is_2x2x(): + if is_rowwise and is_colwise: rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis @@ -299,16 +282,8 @@ def quantize( flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" - is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = None @@ -974,16 +949,8 @@ def quantize( flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" - is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise assert is_rowwise or is_colwise, "No quantization layout is specified" original_shape = x.shape diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index d490e0275..eea27a35d 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -21,7 +21,8 @@ from jax.tree_util import register_pytree_node_class import jax.numpy as jnp -from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout +from transformer_engine_jax import JAXX_Scaling_Mode +from .misc import QuantizeLayout from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -72,16 +73,18 @@ class QuantizeShardyRules: Attributes: input_spec: Specification for the input axes - rowwise_rule: Sharding rule for the row-wise scale tensor, depends on - the axes in `input_spec` - colwise_rule: Likewise for the column-wise scale tensor. - factor_sizes: For block scaling, contains the block size factor, which is - used in `input_spec`. + rowwise_out_spec: Sharding spec for the rowwise quantized data + rowwise_scale_spec: Sharding spec for the rowwise scale + colwise_out_spec: Sharding spec for the colwise quantized data + colwise_scale_spec: Sharding spec for the colwise scale + factor_sizes: For block scaling, contains the block size factor """ input_spec: Tuple[str] - rowwise_rule: Tuple[str] - colwise_rule: Tuple[str] + rowwise_out_spec: Tuple[str] + rowwise_scale_spec: Tuple[str] + colwise_out_spec: Tuple[str] + colwise_scale_spec: Tuple[str] factor_sizes: Dict[str, int] @@ -166,7 +169,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -174,7 +179,9 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + q_layout: The layout of the quantized tensor broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. + is_colwise_transposed: Whether the column-wise tensors are transposed. Returns: The Shardy rules for the scaling mode @@ -268,7 +275,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -281,10 +290,17 @@ def get_shardy_sharding_rules( Returns: The Shardy rules for the scaling mode """ - del flatten_axis, broadcast_2d_scale_shape_to_1d - input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) - scale_var = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + del broadcast_2d_scale_shape_to_1d + input_spec = tuple(f"{unique_var}_x_{i}" for i in range(len(input_shape))) + output_spec = tuple(input_spec) + return QuantizeShardyRules( + input_spec, + output_spec, + (BATCHING + f"{unique_var}_scale",), + (BATCHING + f"{unique_var}_colwise_output",), + (BATCHING + f"{unique_var}_colwise_scale",), + {}, + ) class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): @@ -376,7 +392,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -385,14 +403,26 @@ def get_shardy_sharding_rules( unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. - + q_layout: The layout of the quantized tensor + is_colwise_transposed: Whether the colwise scaling is transposed Returns: The Shardy rules for the scaling mode """ - del flatten_axis, broadcast_2d_scale_shape_to_1d - input_spec = tuple(f"{unique_var}{i}" for i in range(len(input_shape))) - scale_var = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + del broadcast_2d_scale_shape_to_1d + input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape))) + output_spec = input_spec + colwise_output_spec = (BATCHING + f"{unique_var}_colwise_output",) + + if q_layout.has_colwise: + from ..cpp_extensions.misc import multidim_transpose + + colwise_output_spec = input_spec + if is_colwise_transposed: + colwise_output_spec = multidim_transpose( + colwise_output_spec, transpose_axis=flatten_axis + ) + scale = (BATCHING + unique_var + "_scale_inv",) + return QuantizeShardyRules(input_spec, output_spec, scale, colwise_output_spec, scale, {}) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -658,7 +688,9 @@ def get_shardy_sharding_rules( input_shape, unique_var, flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. @@ -666,15 +698,18 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization + q_layout: The layout of the quantized tensor broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. - + is_colwise_transposed: Whether the column-wise tensors are transposed. Returns: The Shardy rules for the scaling mode """ - # TODO(Phuong): to rework the shardy rule to handle transposes after NVFP4 is upstreamed + is_rowwise = q_layout.has_rowwise + is_colwise = q_layout.has_colwise + input_rank = len(input_shape) - input_spec = [f"{unique_var}_{i}" for i in range(input_rank)] flatten_axis = (flatten_axis + input_rank) % input_rank + input_spec = [f"{unique_var}_x_{i}" for i in range(input_rank)] assert ( self._block_dims[1] != 1 @@ -690,30 +725,56 @@ def get_shardy_sharding_rules( # We have to use two different factors in the two CompoundFactors because of Shardy # verifier requirements, even though they are the same. + # No CompoundFactor is needed if the dim has the same size as the blocksize blocksizes = {} - colwise_var = f"{unique_var}_None" rowwise_var = f"{unique_var}_None" - if not input_shape[-1] == block_size_1d: + colwise_var = f"{unique_var}_None" + if is_rowwise and not input_shape[-1] == block_size_1d: rowwise_var = input_spec[-1] + "_compound" input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") blocksizes["blocksize_x"] = block_size_1d - if not input_shape[flatten_axis - 1] == block_size_1d: + if is_colwise and not input_shape[flatten_axis - 1] == block_size_1d: colwise_var = input_spec[flatten_axis - 1] + "_compound" input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") blocksizes["blocksize_y"] = block_size_1d # The rowwise and colwise scale tensors should be sharded the same way as the input. # However, we need to adjust the dimensions where the block scaling factor applies. - rowwise = input_spec.copy() - rowwise[-1] = rowwise_var + if is_rowwise: + rowwise_out = input_spec.copy() + rowwise_scale = input_spec.copy() + rowwise_scale[-1] = rowwise_var + else: + rowwise_out = [ + BATCHING + f"{unique_var}_rowwise_output", + ] + rowwise_scale = [ + BATCHING + f"{unique_var}_rowwise_scale_inv", + ] - colwise = input_spec.copy() - colwise[flatten_axis - 1] = colwise_var + if is_colwise: + colwise_out = input_spec.copy() + colwise_scale = input_spec.copy() + colwise_scale[flatten_axis - 1] = colwise_var + if is_colwise_transposed: + from ..cpp_extensions.misc import multidim_transpose + + colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis) + colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis) + else: + colwise_out = [ + BATCHING + f"{unique_var}_colwise_output", + ] + colwise_scale = [ + BATCHING + f"{unique_var}_colwise_scale_inv", + ] return QuantizeShardyRules( tuple(input_spec), - tuple(rowwise), - tuple(colwise), + tuple(rowwise_out), + tuple(rowwise_scale), + tuple(colwise_out), + tuple(colwise_scale), blocksizes, ) @@ -850,7 +911,8 @@ def get_shardy_sharding_rules( self, input_shape, unique_var, - flatten_axis=-1, + flatten_axis, + q_layout, broadcast_2d_scale_shape_to_1d=False, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. @@ -859,13 +921,19 @@ def get_shardy_sharding_rules( input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix flatten_axis: Axis along which data can be flattened to 2D for quantization. + q_layout: The layout of the quantized tensor broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: The Shardy rules for the scaling mode """ return self._get_impl().get_shardy_sharding_rules( - input_shape, unique_var, flatten_axis, broadcast_2d_scale_shape_to_1d + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + self.is_colwise_transposed, ) def get_grouped_scale_shape_2x( diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 6c358a044..25db84409 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -15,10 +15,10 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap +from .misc import QuantizeLayout from ..sharding import ( with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, ) @@ -128,9 +128,7 @@ def dequantize(self): def get_tensor(self, usage: TensorUsage): """Returns the tensor based on the tensor usage.""" q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) - assert ( - q_layout == QuantizeLayout.ROWWISE - ), "Only ROWWISE layout is supported for NoScaleTensor" + assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor" return self def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): @@ -264,8 +262,8 @@ def dequantize(self): def get_tensor(self, usage: TensorUsage): """Returns the tensor based on the tensor usage.""" q_layout = self.scaling_mode.get_quantize_layout(usage) - colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise - rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise + colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise + rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise if colwise_usage_valid or rowwise_usage_valid: return self @@ -301,16 +299,15 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - # TODO(Phuong): Handle padding !? + if self.scaling_mode.is_block_scaling: # Both MXFP8 and NVFP4 scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv return ScaledTensor1x( data=data, - scale_inv=scale_inv, amax=self.amax, + scale_inv=scale_inv, scaling_mode=self.scaling_mode, dq_dtype=self.dq_dtype, _dq_func=self._dq_func, @@ -467,10 +464,10 @@ def get_tensor(self, usage: TensorUsage): q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage) q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage) - if q_layout_rowwise == QuantizeLayout.ROWWISE: + if q_layout_rowwise.is_rowwise_only: return self.rowwise_tensor - if q_layout_colwise == QuantizeLayout.COLWISE: + if q_layout_colwise.is_colwise_only: return self.colwise_tensor raise ValueError( @@ -548,13 +545,13 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) if group_sizes is not None: - flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" # Handling attrs of transposed tensors - group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis + group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": if original_shape[0] == group_sizes.size: original_shape = ( @@ -587,7 +584,7 @@ def create_1x( ) # Handling attrs of transposed tensors - flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + flatten_axis = (data.ndim + flatten_axis) % data.ndim if data_layout == "T": flatten_axis = data.ndim - flatten_axis @@ -669,7 +666,7 @@ def create_2x( colwise_amax, scaling_mode, dq_dtype, - is_colwise=True, # TODO(Phuong): set this correctly + is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, group_sizes=group_sizes, @@ -721,7 +718,7 @@ def create( """ assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet" - if q_layout == QuantizeLayout.ROWWISE_COLWISE: + if q_layout.is_rowwise_colwise: return ScaledTensorFactory.create_2x( data, scale_inv, @@ -740,15 +737,14 @@ def create( colwise_has_rht_applied=colwise_has_rht_applied, ) - is_colwise = q_layout == QuantizeLayout.COLWISE - if is_colwise: + if q_layout.is_colwise_only: return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, colwise_amax if colwise_amax is not None else amax, scaling_mode, dq_dtype, - is_colwise=is_colwise, + is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, group_sizes=group_sizes, @@ -763,7 +759,7 @@ def create( amax, scaling_mode, dq_dtype, - is_colwise=is_colwise, + is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, group_sizes=group_sizes, From 67d63d02f3efe1b8e0984788cc4e9ebf93bfd703 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Thu, 13 Nov 2025 13:58:32 -0800 Subject: [PATCH 122/141] [JAX] Support for checkpointing quantizations (#2356) * Support for checkpointing quantizations Signed-off-by: Jeremy Berchtold * Add jaxpr test for quant checkpoint name Signed-off-by: Jeremy Berchtold * Revert "Support for checkpointing quantizations" This reverts commit f7b784940369d0da2a77c57fa6ea744e883c5832. Signed-off-by: JAX Toolbox * Checkpoint quantizations Signed-off-by: Jeremy Berchtold * lint Signed-off-by: Jeremy Berchtold * revert other files Signed-off-by: Jeremy Berchtold * move checkpointing to VJPs Signed-off-by: Jeremy Berchtold * fix ci failure Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: JAX Toolbox Co-authored-by: JAX Toolbox --- tests/jax/test_recipe_characteristics.py | 121 +++++++++++++------ transformer_engine/jax/dense.py | 13 +- transformer_engine/jax/flax/module.py | 37 ++++-- transformer_engine/jax/layernorm_dense.py | 4 +- transformer_engine/jax/layernorm_mlp.py | 8 +- transformer_engine/jax/quantize/quantizer.py | 76 ++++++++++-- transformer_engine/jax/quantize/tensor.py | 55 +++++++++ 7 files changed, 252 insertions(+), 62 deletions(-) diff --git a/tests/jax/test_recipe_characteristics.py b/tests/jax/test_recipe_characteristics.py index 33fde7e23..b9c8fd783 100644 --- a/tests/jax/test_recipe_characteristics.py +++ b/tests/jax/test_recipe_characteristics.py @@ -263,23 +263,16 @@ def test_autocast_nvfp4_block_scaling(self): class TestJaxprAndHlo: """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations.""" - @pytest_parametrize_wrapper( - "quantization_recipe", - [ - quantization_recipe - for quantization_recipe in SUPPORTED_RECIPES - if isinstance(quantization_recipe, NVFP4BlockScaling) - ], - ) - def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe): - """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton.""" - + def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None): + """Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe.""" + ln_mlp_kwargs = ln_mlp_kwargs or {} with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()): model = te_flax.LayerNormMLP( layernorm_type="rmsnorm", return_layernorm_output=False, intermediate_dropout_rate=0.0, dtype=jnp.bfloat16, + **ln_mlp_kwargs, ) var_collect = model.init( @@ -292,29 +285,83 @@ def loss_fn(x, rngs): x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16) rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)} - jaxpr = jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs) - - rht_amax_eqns = [ - eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper" - ] - - assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}" - - def assert_param(index, tensor_name, expected_value: bool): - if expected_value: - assert rht_amax_eqns[index].params["produce_regular_amax"] == True, ( - f"Expected produce_regular_amax for {tensor_name} to be True, indicating no" - " reuse of amax as this tensor does not have a previous operation to fuse" - " with" - ) - else: - assert rht_amax_eqns[index].params["produce_regular_amax"] == False, ( - f"Expected produce_regular_amax for {tensor_name} to be False, indicating" - " reuse of amax" - ) - - assert_param(0, "fwd ln+q", False) - assert_param(1, "fwd act+q", False) - # No previous op before incoming dgrad in the backward so amax is not reused - assert_param(2, "bwd dgrad", True) - assert_param(3, "bwd dact+q", False) + return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs) + + @pytest_parametrize_wrapper( + "quantization_recipe", + [ + quantization_recipe + for quantization_recipe in SUPPORTED_RECIPES + if isinstance(quantization_recipe, NVFP4BlockScaling) + ], + ) + def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe): + """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton.""" + + jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe) + + rht_amax_eqns = [ + eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper" + ] + + assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}" + + def assert_param(index, tensor_name, expected_value: bool): + if expected_value: + assert rht_amax_eqns[index].params["produce_regular_amax"] == True, ( + f"Expected produce_regular_amax for {tensor_name} to be True, indicating no" + " reuse of amax as this tensor does not have a previous operation to fuse" + " with" + ) + else: + assert rht_amax_eqns[index].params["produce_regular_amax"] == False, ( + f"Expected produce_regular_amax for {tensor_name} to be False, indicating" + " reuse of amax" + ) + + assert_param(0, "fwd ln+q", False) + assert_param(1, "fwd act+q", False) + # No previous op before incoming dgrad in the backward so amax is not reused + assert_param(2, "bwd dgrad", True) + assert_param(3, "bwd dact+q", False) + + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper( + "quantization_checkpoint_name", + [None, "quantization", "some_arbitrary_user_checkpoint_name"], + ) + def test_recipe_supports_quantization_checkpointing( + self, quantization_recipe, quantization_checkpoint_name + ): + """Tests that all supported quantization recipes correctly use checkpoint_name.""" + + kwargs = { + "quantization_checkpoint_name": quantization_checkpoint_name, + } + jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs) + + checkpoint_name_eqns = [ + eqn + for eqn in jaxpr.jaxpr.eqns + if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name + ] + + if quantization_checkpoint_name is None: + assert len(checkpoint_name_eqns) == 0, ( + "Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got" + f" {len(checkpoint_name_eqns)}" + ) + return + + # 12 checkpointed values: + # - Fwd pass: + # - Input RMSNorm+Q -> 3 possible output tensors that will be used in the backward + # - Kernel Q -> 3 possible output tensors that will be used in the backward + # - Input Activation+Q -> 3 possible output tensors that will be used in the backward + # - Kernel Q -> 3 possible output tensors that will be used in the backward + expected_checkpoint_eqn_count = 12 + + assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, ( + f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when" + f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}" + ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 44c73a5b1..c497775e0 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -19,6 +19,7 @@ from .cpp_extensions.amax import AmaxScope from .quantize import ( ScaledTensorFactory, + ScaledTensor, ScalingMode, QuantizeLayout, QuantizerSet, @@ -227,8 +228,8 @@ def _dense_fwd_rule( output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_x.get_tensor(usage=TensorUsage.LHS_TRANS), - casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS), + casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x), + casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel), x.shape, kernel.shape, use_bias, @@ -529,8 +530,12 @@ def _grouped_dense_fwd_rule( ctx = ( group_sizes, - ctx_x, - ctx_kernel, + ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x, + ( + ctx_kernel.checkpoint(quantizer_set.kernel) + if isinstance(ctx_kernel, ScaledTensor) + else ctx_kernel + ), x.shape, kernel.shape, use_bias, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 33ea61098..934af3d18 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -6,7 +6,7 @@ """ from functools import reduce import operator -from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType +from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional import numpy as np import jax.numpy as jnp @@ -345,7 +345,11 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method """ def generate_quantizer_set( - self, postfix: str = "", variable_collection: str = None, fp8_recipe=None + self, + postfix: str = "", + variable_collection: str = None, + quantization_checkpoint_name: Optional[str] = None, + fp8_recipe=None, ): """ Generate a set of FP8 meta for a GEMM. @@ -375,7 +379,9 @@ def generate_quantizer_set( quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) quantizer_set = QuantizerFactory.create_set( - fp8_recipe=fp8_recipe, quantize_meta_set=quantize_meta_set + fp8_recipe=fp8_recipe, + quantize_meta_set=quantize_meta_set, + checkpoint_name=quantization_checkpoint_name, ) return quantizer_set @@ -424,6 +430,8 @@ class DenseGeneral(TransformerEngineBase): The data type used to allocate the initial parameters. transpose_batch_sequence: bool, default = False Indicate whether to transpose the batch and sequence dimensions of the input tensor. + quantization_checkpoint_name: Optional[str], default = None + The name for checkpointing quantizations. """ features: Union[Iterable[int], int] @@ -439,6 +447,7 @@ class DenseGeneral(TransformerEngineBase): dtype: DType = jnp.float32 input_axes: Tuple[str, ...] = () transpose_batch_sequence: bool = False + quantization_checkpoint_name: Optional[str] = None def __post_init__(self): if self.kernel_init is None: @@ -496,7 +505,9 @@ def __call__(self, inputs: Array) -> Array: else: bias = None - quantizer_set = self.generate_quantizer_set() + quantizer_set = self.generate_quantizer_set( + quantization_checkpoint_name=self.quantization_checkpoint_name + ) contract_ind = tuple(range(0, len(axis))) y = dense( inputs, @@ -628,6 +639,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): value or None. When None is set, then no scaling is applied. transpose_batch_sequence: bool, default = False Indicate whether to transpose the batch and sequence dimensions of the input tensor. + quantization_checkpoint_name: Optional[str], default = None + The name for checkpointing quantizations. """ features: Union[Iterable[int], int] @@ -654,6 +667,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): dot_input_axes: Tuple[str, ...] = None depth_scaling: float = None transpose_batch_sequence: bool = False + quantization_checkpoint_name: Optional[str] = None def __post_init__(self): if self.kernel_init is None: @@ -693,7 +707,9 @@ def __call__(self, inputs: Array) -> Array: input_dtype = inputs.dtype ln_output = None - quantizer_set = self.generate_quantizer_set() + quantizer_set = self.generate_quantizer_set( + quantization_checkpoint_name=self.quantization_checkpoint_name + ) fuse_layernorm = ( get_quantize_config().is_fp8_enabled() @@ -941,6 +957,8 @@ class LayerNormMLP(TransformerEngineBase): The data type used to allocate the initial parameters. transpose_batch_sequence: bool, default = False Indicate whether to transpose the batch and sequence dimensions of the input tensor. + quantization_checkpoint_name: Optional[str], default = None + The name for checkpointing quantizations. """ intermediate_dim: int = 2048 @@ -976,6 +994,7 @@ class LayerNormMLP(TransformerEngineBase): ffn1_ckpt_name: str = "ffn1" ffn2_ckpt_name: str = "ffn2" transpose_batch_sequence: bool = False + quantization_checkpoint_name: Optional[str] = None def __post_init__(self): if self.kernel_init is None: @@ -1010,8 +1029,12 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: """ assert self.axis == -1, "Only support axis == -1 at this moment" - ffn1_quantizer_set = self.generate_quantizer_set("_0") - ffn2_quantizer_set = self.generate_quantizer_set("_1") + ffn1_quantizer_set = self.generate_quantizer_set( + "_0", quantization_checkpoint_name=self.quantization_checkpoint_name + ) + ffn2_quantizer_set = self.generate_quantizer_set( + "_1", quantization_checkpoint_name=self.quantization_checkpoint_name + ) input_dtype = inputs.dtype ln_output = None diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 705c74232..b9482b7bd 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -236,8 +236,8 @@ def _layernorm_dense_fwd_rule( output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), - casted_kernel.get_tensor(TensorUsage.RHS_TRANS), + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x), + casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel), x.shape, kernel.shape, mu, diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 100848fdd..2fd0f07d6 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -390,11 +390,11 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), - casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS), + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x), + casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel), dot_1_output, - casted_act_out.get_tensor(TensorUsage.LHS_TRANS), - casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS), + casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x), + casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel), x_contracting_dims, k_contracting_dims, kernel_1.shape, diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 8a54f0b1d..adff31748 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -83,12 +83,15 @@ class Quantizer(ABC): q_dtype: The data type for quantized values scaling_mode: The scaling mode to use for quantization q_layout: The quantization axis (row-wise, column-wise, or both) + data_layout: The data layout string (e.g., "NT") + checkpoint_name: Optional name for checkpointing quantization state """ q_dtype: jnp.dtype scaling_mode: ScalingMode q_layout: QuantizeLayout data_layout: str + checkpoint_name: Optional[str] = None def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -97,7 +100,13 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = () - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + ) return (children, aux_data) @classmethod @@ -337,7 +346,13 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.scale, self.amax_history) - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + ) return (children, aux_data) def _quantize_func( @@ -588,7 +603,14 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.stochastic_rounding_rng_state,) - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.use_rht) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + self.use_rht, + ) return (children, aux_data) @classmethod @@ -867,7 +889,14 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.quantizers,) - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + self.n_groups, + ) return (children, aux_data) def __post_init__(self): @@ -1041,6 +1070,7 @@ def create( q_dtype: jnp.dtype = None, q_layout: QuantizeLayout = None, n_groups: int = None, + checkpoint_name: Optional[str] = None, **kwargs, ) -> Quantizer: """Create one or more quantizers with specified parameters. @@ -1052,6 +1082,7 @@ def create( q_layout: Quantization axis flatten_axis: The quantization axis for the tensor n_groups: Number of quantizers if GroupedQuantizer + checkpoint_name: Optional name for checkpointing quantizations **kwargs: Additional arguments for quantizer initialization Returns: @@ -1075,7 +1106,11 @@ def create( for _ in range(n_quantizers): quantizers.append( quantizer_type( - q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + checkpoint_name=checkpoint_name, + **kwargs, ) ) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) @@ -1089,6 +1124,7 @@ def _create_set( bwd_dtype, is_2x2x, n_groups, + checkpoint_name: Optional[str] = None, **kwargs, ) -> QuantizerSet: """Create a set of quantizers for forward and backward passes. @@ -1101,6 +1137,7 @@ def _create_set( bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization n_groups + checkpoint_name: Optional name for checkpointing quantizations **kwargs: Additional arguments for quantizer initialization Returns: @@ -1123,12 +1160,32 @@ def _create_set( else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) + q_x = QuantizerFactory.create( + 1, + x_scaling_mode, + fwd_dtype, + q_layout_x, + n_groups, + checkpoint_name=checkpoint_name, + **args_x, + ) q_kernel = QuantizerFactory.create( - 1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel + 1, + kernel_scaling_mode, + fwd_dtype, + q_layout_kernel, + n_groups, + checkpoint_name=checkpoint_name, + **args_kernel, ) q_dgrad = QuantizerFactory.create( - 1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad + 1, + grad_scaling_mode, + bwd_dtype, + q_layout_dgrad, + n_groups, + checkpoint_name=checkpoint_name, + **args_grad, ) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @@ -1140,6 +1197,7 @@ def create_set( bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, n_groups: int = None, + checkpoint_name: Optional[str] = None, # TODO(jberchtold): rename fp8_recipe to quantization_recipe fp8_recipe: Optional[recipe.Recipe] = None, **kwargs, @@ -1153,6 +1211,7 @@ def create_set( bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X n_groups: + checkpoint_name: Optional name for checkpointing quantizations fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. **kwargs: Additional arguments for quantizer initialization @@ -1208,6 +1267,7 @@ def create_set( bwd_dtype=bwd_dtype, is_2x2x=is_2x2x, n_groups=n_groups, + checkpoint_name=checkpoint_name, **kwargs, ) ) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 25db84409..90f139c3d 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -14,6 +14,7 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class +from jax.ad_checkpoint import checkpoint_name as jax_checkpoint_name from .scaling_modes import ScalingMode, TensorUsage @@ -89,6 +90,17 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st The tensor with applied sharding constraints """ + @abstractmethod + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + @dataclass class AbstractBaseTensor1x(AbstractBaseTensor): @@ -150,6 +162,18 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st amax=self.amax, ) + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + assert quantizer is None, "NoScaleTensor does not support quantization." + return self + class ScaledTensor(ABC): """Abstract base class for scaled tensors.""" @@ -317,6 +341,20 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st has_rht_applied=self.has_rht_applied, ) + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + if quantizer is None or quantizer.checkpoint_name is None: + return self + + return jax_checkpoint_name(self, name=quantizer.checkpoint_name) + @register_pytree_node_class @dataclass @@ -420,6 +458,20 @@ def tree_flatten(self): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): raise NotImplementedError + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + if quantizer is None or quantizer.checkpoint_name is None: + return self + + return jax_checkpoint_name(self, name=quantizer.checkpoint_name) + @register_pytree_node_class @dataclass @@ -496,6 +548,9 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return ScaledTensor2x(rowwise_tensor, colwise_tensor) + def checkpoint(self, quantizer): + raise NotImplementedError + @dataclass class ScaledTensorFactory: From 0ded11340ba28267d0826fca550e0666ea8a00aa Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 13 Nov 2025 18:40:37 -0500 Subject: [PATCH 123/141] [JAX] XLA_FLAG to WAR the current NCCL issue with test_distributed_softmax.py (#2378) * add war for test_distributed_softmax.py Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- qa/L1_jax_distributed_unittest/test.sh | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index f4ea2dd68..886f27747 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -28,7 +28,9 @@ python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/py python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_mlp.xml $TE_PATH/tests/jax/test_distributed_layernorm_mlp.py || test_fail "test_distributed_layernorm_mlp.py" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_softmax.xml $TE_PATH/tests/jax/test_distributed_softmax.py || test_fail "test_distributed_softmax.py" +# XLA_FLAGS to WAR for test_distributed_softmax issue with NCCL +# TODO(Kshitij): remove when NCCL issue is fixed +XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_softmax.xml $TE_PATH/tests/jax/test_distributed_softmax.py || test_fail "test_distributed_softmax.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" From 262c184eb8331dcb50037477881900e46bd5c5f2 Mon Sep 17 00:00:00 2001 From: Robin Zhang Date: Fri, 14 Nov 2025 23:53:48 +0800 Subject: [PATCH 124/141] [PyTorch] Add reset cudagraph interface (#2367) * reset cudagraph Signed-off-by: Robin Zhang * use closure instead of mutable default values Signed-off-by: Robin Zhang * add test Signed-off-by: Robin Zhang * fix test Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_cuda_graphs.py | 38 +++++++++++++++++++++++++---- transformer_engine/pytorch/graph.py | 27 +++++++++++++------- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index fa8754d60..eacbf5168 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -from typing import Iterable, List, Union +from typing import Callable, Dict, Iterable, List, Tuple, Union import pytest import torch @@ -160,6 +160,20 @@ def get_outputs( return values +def reset_graphs( + graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]], +) -> None: + """Reset CUDA graphs.""" + if isinstance(graphed_callables, tuple) or isinstance(graphed_callables, list): + for callable in graphed_callables: + callable.reset() + elif isinstance(graphed_callables, dict): + for callable in graphed_callables.values(): + callable.reset() + else: + graphed_callables.reset() + + class _Sequential(torch.nn.Sequential): """Sequential model that forwards keyword arguments to modules""" @@ -322,7 +336,12 @@ def _test_cuda_graphs( output.backward(grad_output) optimizer.step() - return get_outputs(model, output) + outputs = get_outputs(model, output) + if graph_mode == "full": + reset_graphs(model) + elif graph_mode == "individual": + reset_graphs(modules) + return outputs @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @@ -468,7 +487,10 @@ def _test_cuda_graphs_with_dot_product_attention( output = model(*inputs) output.backward(grad_output) - return get_outputs(model, output) + outputs = get_outputs(model, output) + if with_graph: + reset_graphs(model) + return outputs @pytest.mark.parametrize("dtype", dtypes) @@ -553,7 +575,10 @@ def _test_cuda_graphs_with_kwargs( output.backward(grad_output) optimizer.step() - return get_outputs(model, output) + outputs = get_outputs(model, output) + if with_graph: + reset_graphs(model) + return outputs def test_make_graphed_callables_with_kwargs( @@ -668,7 +693,10 @@ def backward(layer_idx: int, microbatch_idx: int): optimizer.step() outputs = [y for _, y in sorted(outputs.items())] - return get_outputs(model, outputs) + outputs = get_outputs(model, outputs) + if with_graph: + reset_graphs(layer_forwards) + return outputs def test_make_graphed_callables_with_interleaved_pipeline_parallelism( diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 9af9fb887..f55f1dd12 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -756,6 +756,21 @@ def functionalized(*user_args, **user_kwargs): return functionalized + def make_graphed_attribute_functions(graph_idx): + + # Attach backward_dw as an attribute to the graphed callable. + def backward_dw(): + if need_bwd_dw_graph.get(graph_idx, False): + bwd_dw_graphs[graph_idx].replay() + + # Attach reset as an attribute to the graphed callable. + def reset(): + fwd_graphs[graph_idx].reset() + bwd_graphs[graph_idx].reset() + bwd_dw_graphs[graph_idx].reset() + + return backward_dw, reset + # Put together the final graphed callables ret = [] for i in range(len(sample_args)): @@ -831,15 +846,9 @@ def new_fwd(*user_args, **user_kwargs): else: ret.append(graphed) - # Attach backward_dw as an attribute to the graphed callable. - def backward_dw( - need_backward_dw=need_bwd_dw_graph.get(i, False), - bwd_dw_graph=bwd_dw_graphs[i], - ): - if need_backward_dw: - bwd_dw_graph.replay() - - setattr(ret[-1], "backward_dw", backward_dw) + backward_dw_func, reset_func = make_graphed_attribute_functions(i) + setattr(ret[-1], "backward_dw", backward_dw_func) + setattr(ret[-1], "reset", reset_func) if just_one_callable: return ret[0] From b88f727b44d7779200a7f57c279805930a3883ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 14 Nov 2025 18:26:42 +0100 Subject: [PATCH 125/141] [JAX] Make all jax attention calls use non-packed common calls (#2358) * fix Signed-off-by: Pawel Gadzinski * add notes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * small fixes Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/fused_attn/fused_attn.cpp | 2 +- .../jax/csrc/extensions/attention.cpp | 331 ++++++++---------- 2 files changed, 140 insertions(+), 193 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ac6fefdc6..611beb7b8 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -29,7 +29,7 @@ transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *so return view; } -// Helper function to calculate stride for packed QKV tensor unpacking +// Helper function to calculate stride in bytes for packed QKV tensor unpacking size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, size_t h, size_t d) { size_t stride = 0; diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index a99f4fae9..ac7eba5c8 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -123,17 +123,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { - // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - - // For separate q, k, v auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; @@ -156,7 +147,6 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; @@ -174,37 +164,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); auto ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, is_training, - false, false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, query_workspace_tensor.data(), - nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - ragged_offset_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, softmax_type, window_size_left, window_size_right, - query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported QKVLayout."); - } + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -291,47 +258,57 @@ static void FusedAttnForwardImpl( /* Call the underlying NVTE API */ auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + + // Prepare Q, K, V pointers and shapes based on layout + // Python passes dummy tensors for unused slots, so we extract from the actual packed data + void *q_ptr = q; + void *k_ptr = k; + void *v_ptr = v; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), q_max_seqlen, is_training, false, - false, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); + // QKV packed in q: [batch*seqlen, 3, heads, dim] + // Python passes: q=packed_qkv, k=dummy, v=dummy + // Extract K and V pointers from the packed q data + NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen"); + NVTE_CHECK(qk_head_dim == v_head_dim, + "For QKV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim); + q_ptr = q; + k_ptr = static_cast(static_cast(q) + stride); + v_ptr = static_cast(static_cast(q) + 2 * stride); + // For packed QKV, all have same shape since they're views into the same packed tensor + k_shape = q_shape; + v_shape = q_shape; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), dummy_softmax_offset_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, false, false, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); + // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim] + // Python passes: q=query, k=packed_kv, v=dummy + // Extract V pointer from the packed k data + NVTE_CHECK(qk_head_dim == v_head_dim, + "For KV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim); + q_ptr = q; + k_ptr = k; + v_ptr = static_cast(static_cast(k) + stride); + // V has same shape as K since they're packed together + v_shape = k_shape; } + // else NVTE_HD_HD_HD: pointers and shapes already correct + + auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype); + auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype); + auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype); + + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -414,20 +391,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { - // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - - // For separate q, k, v auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); @@ -450,7 +416,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper query_workspace_tensor; - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; @@ -471,42 +436,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked( - qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, - deterministic, false, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_bwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, - window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported qkv_layout."); - } + + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, window_size_left, + window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -552,76 +493,82 @@ static void FusedAttnBackwardImpl( softmax_aux, rng_state, bias); /* Call the underly NVTE API */ + // Prepare Q, K, V pointers and shapes based on layout + void *q_ptr = q; + void *k_ptr = k; + void *v_ptr = v; + void *dq_ptr = dq; + void *dk_ptr = dk; + void *dv_ptr = dv; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); - if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), q_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - softmax_type, window_size_left, window_size_right, deterministic, - false, workspace_tensor.data(), stream); + // QKV packed in q: [batch*seqlen, 3, heads, dim] + NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen"); + NVTE_CHECK(qk_head_dim == v_head_dim, + "For QKV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim); + q_ptr = q; + k_ptr = static_cast(static_cast(q) + stride); + v_ptr = static_cast(static_cast(q) + 2 * stride); + dq_ptr = dq; + dk_ptr = static_cast(static_cast(dq) + stride); + dv_ptr = static_cast(static_cast(dq) + 2 * stride); + k_shape = q_shape; + v_shape = q_shape; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); - if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - q_max_seqlen, kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, deterministic, false, - workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); - if (is_ragged) { - cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); - cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); + // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim] + NVTE_CHECK(qk_head_dim == v_head_dim, + "For KV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim); + q_ptr = q; + k_ptr = k; + v_ptr = static_cast(static_cast(k) + stride); + dq_ptr = dq; + dk_ptr = dk; + dv_ptr = static_cast(static_cast(dk) + stride); + // V has same shape as K since they're packed together + v_shape = k_shape; + } + + auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype); + auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype); + auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype); + auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype); + auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype); + auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype); + + if (is_ragged) { + size_t dtype_size = typeToSize(dtype); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + // For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once + cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream); + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + // Clear dq + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream); + // For packed KV, dk contains both dk and dv - clear all at once + cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream); + } else { + // All separate - clear each individually + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream); + cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream); + cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream); } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, softmax_type, window_size_left, window_size_right, deterministic, - false, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); } + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); + nvte_tensor_pack_destroy(&aux_input_tensors); } From a0754757660ac5a747b0be54c5398fca032161aa Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 14 Nov 2025 10:21:03 -0800 Subject: [PATCH 126/141] [JAX] Improve support and testing for direct recipe usage without autocast contexts (#2366) * Refactor to avoid storing a global quantization config so direct recipe passing works as intended Signed-off-by: Jeremy Berchtold * fix use_split_accumulator for current scaling recipe Signed-off-by: Jeremy Berchtold * fix tests that pass direct recipe and were missing quantize meta set Signed-off-by: Jeremy Berchtold * Revert "fix use_split_accumulator for current scaling recipe" This reverts commit a74ab7df812ec0a069b1bdd208debb93ec25a900. Signed-off-by: Jeremy Berchtold * fix ci failures Signed-off-by: Jeremy Berchtold * Fix amax_history post_init Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * Update transformer_engine/jax/quantize/quantizer.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci failures Signed-off-by: Jeremy Berchtold * fix ci issue Signed-off-by: Jeremy Berchtold * address comments Signed-off-by: Jeremy Berchtold * make recipe assertion classes in test_recipe_characteristics not inherit from unittest.TestCase Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_custom_call_compute.py | 19 +- tests/jax/test_layer.py | 21 +- tests/jax/test_recipe_characteristics.py | 338 +++++++++++------- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/cpp_extensions/gemm.py | 23 +- .../jax/cpp_extensions/quantization.py | 4 +- transformer_engine/jax/dense.py | 3 +- transformer_engine/jax/flax/module.py | 34 +- transformer_engine/jax/layernorm_dense.py | 3 +- transformer_engine/jax/layernorm_mlp.py | 3 +- transformer_engine/jax/quantize/helper.py | 47 ++- transformer_engine/jax/quantize/quantizer.py | 89 +++-- 12 files changed, 373 insertions(+), 213 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3d4f179ab..c8bd9d47c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -40,6 +40,8 @@ QuantizerFactory, QuantizeLayout, noop_quantizer_set, + QuantizeMetaSet, + QuantizeMeta, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -1457,7 +1459,12 @@ def ref_func(x, w, bias, data_layout): value_n_grad_primitive_func = value_and_grad(primitive_func, (0, 1, 2)) value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) - quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) + quantizer_set = QuantizerFactory.create_set( + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), + ) n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): @@ -1516,7 +1523,12 @@ def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm): gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) - quantizer_set = QuantizerFactory.create_set(fp8_recipe=recipe) + quantizer_set = QuantizerFactory.create_set( + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), + ) if norm_type == "layernorm": beta = jax.random.normal(subkeys[3], (k,)).astype(jnp.bfloat16) @@ -1605,6 +1617,9 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), ) if norm_type == "layernorm": diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d1b2535c4..b51d6b213 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -23,7 +23,8 @@ from transformer_engine.common import recipe from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.quantize import ( - get_quantize_config, + get_global_quantize_recipe, + get_quantize_config_with_recipe, ScalingMode, is_fp8_available, update_collections, @@ -358,7 +359,7 @@ def test_backward( ref_params, test_params = self._sync_params(ref_params, test_params) - if get_quantize_config().is_fp8_enabled(): + if get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled(): for _ in range(4): _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( inputs, @@ -368,14 +369,24 @@ def test_backward( test_layer, ) if ( - get_quantize_config().get_scaling_mode(TensorSource.X) + get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode( + TensorSource.X + ) == ScalingMode.DELAYED_TENSOR_SCALING ): _, updated_quantize_meta = flax.core.pop( - updated_state[0], get_quantize_config().COLLECTION_NAME + updated_state[0], + get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).COLLECTION_NAME, ) test_others = update_collections( - {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others + { + get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).COLLECTION_NAME: updated_quantize_meta + }, + test_others, ) del updated_quantize_meta del updated_state diff --git a/tests/jax/test_recipe_characteristics.py b/tests/jax/test_recipe_characteristics.py index b9c8fd783..5171a6c62 100644 --- a/tests/jax/test_recipe_characteristics.py +++ b/tests/jax/test_recipe_characteristics.py @@ -4,6 +4,7 @@ import unittest from functools import partial +from abc import ABC, abstractmethod import flax import jax @@ -13,6 +14,7 @@ from utils import assert_allclose, pytest_parametrize_wrapper from transformer_engine.common.recipe import ( + Recipe, DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling, @@ -21,13 +23,13 @@ from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.jax import autocast from transformer_engine.jax.quantize import ( - get_quantize_config, + get_global_quantize_recipe, + get_quantize_config_with_recipe, get_supported_quantization_recipes, is_scaling_mode_supported, ScalingMode, update_collections, TensorSource, - QuantizerFactory, QuantizeLayout, ) from transformer_engine.jax.quantize.helper import _format2dtypes @@ -49,16 +51,17 @@ def quantizer_check_vjp(outer_quantizer_set, assertion_func, x): # Define a function with a custom VJP (vector-Jacobian product) @partial(jax.custom_vjp, nondiff_argnums=(1,)) def quantizer_check(inner_quantizer_set, assertion_func, x): - return quantizer_check_fwd(inner_quantizer_set, assertion_func, x) + return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0] def quantizer_check_fwd(inner_quantizer_set, assertion_func, x): assertion_func(inner_quantizer_set.x, TensorSource.X) assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL) assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD) - return x + return x, (inner_quantizer_set,) - def quantizer_check_bwd(ctx, g): - return (g,) + def quantizer_check_bwd(assertion_func, ctx, g): + (inner_quantizer_set,) = ctx + return (inner_quantizer_set, g) quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd) return quantizer_check(outer_quantizer_set, assertion_func, x) @@ -69,10 +72,11 @@ class TestModule(TransformerEngineBase): # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None assertion_func: callable + direct_recipe: Recipe @nn.compact def __call__(self, x): - quantizer_set = self.generate_quantizer_set() + quantizer_set = self.generate_quantizer_set(fp8_recipe=self.direct_recipe) return quantizer_check_vjp(quantizer_set, self.assertion_func, x) @@ -97,167 +101,239 @@ def test_update_collections(self): self.assertEqual(updated_state["test2"], original_val) -class TestFP8Functions(unittest.TestCase): +def assert_fp8_format(quantizer, tensor_source, fp8_format): + if fp8_format == FP8Format.HYBRID: + if tensor_source == TensorSource.DGRAD: + assert quantizer.q_dtype == jnp.float8_e5m2 + else: + assert quantizer.q_dtype == jnp.float8_e4m3fn + elif fp8_format == FP8Format.E4M3: + assert quantizer.q_dtype == jnp.float8_e4m3fn + else: + raise ValueError(f"Unsupported FP8 format: {fp8_format}") - def _check_default_state(self): - self.assertFalse(get_quantize_config().is_fp8_enabled()) - - def _compare_delay_scaling(self, test): - self.assertEqual(get_quantize_config().MARGIN, test.margin) - self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) - self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) - self.assertEqual(get_quantize_config().AMAX_HISTORY_LEN, test.amax_history_len) - self.assertEqual(get_quantize_config().AMAX_COMPUTE_ALGO.value, test.amax_compute_algo) - - def _compare_current_scaling(self, test): - self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) - self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) - for tensor_source in TensorSource: - self.assertEqual( - get_quantize_config().get_scaling_mode(tensor_source), - ScalingMode.CURRENT_TENSOR_SCALING, - ) - def _compare_mxfp8_scaling(self, test): - self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp8_format)[0]) - self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp8_format)[1]) - for tensor_source in TensorSource: - self.assertEqual( - get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING - ) +class RecipeAssertionBase(ABC): + """Base class for defining recipe assertions.""" - def _compare_nvfp4_scaling(self, test): - self.assertEqual(get_quantize_config().FWD_DTYPE, _format2dtypes(test.fp4_format)[0]) - self.assertEqual(get_quantize_config().BWD_DTYPE, _format2dtypes(test.fp4_format)[1]) - for tensor_source in TensorSource: - target_scaling_mode = ( - ScalingMode.NVFP4_2D_SCALING - if (not test.disable_2d_quantization) and tensor_source == TensorSource.KERNEL - else ScalingMode.NVFP4_1D_SCALING - ) - self.assertEqual( - get_quantize_config().get_scaling_mode(tensor_source), target_scaling_mode - ) - self.assertEqual( - get_quantize_config().DISABLE_STOCHASTIC_ROUNDING, test.disable_stochastic_rounding - ) - self.assertEqual(get_quantize_config().DISABLE_RHT, test.disable_rht) - self.assertEqual( - get_quantize_config().DISABLE_2D_QUANTIZATION, test.disable_2d_quantization - ) + @abstractmethod + def assert_context(self, ref_recipe, quantize_config): + """Asserts that the quantize_config matches the expected properties from the reference recipe when the recipe is used with an autocast context. - def _compare_nvfp4_scaling_quantizers(self, test): - """Check that the quantizers created have the expected stochastic rounding state and the state is preserved across VJP boundaries.""" + Args: + ref_recipe: The reference quantization recipe. + quantize_config: The quantization configuration to be checked. + """ + pass - def assertion_func(quantizer, tensor_source): - if test.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD: - self.assertIsNone(quantizer.stochastic_rounding_rng_state) - else: - self.assertIsNotNone(quantizer.stochastic_rounding_rng_state) + @abstractmethod + def assert_quantizers(self, ref_recipe, quantizer, tensor_source): + """Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction. + + Args: + ref_recipe: The reference quantization recipe. + quantizer: The quantizer to be checked. + tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD). + """ + pass - expected_rht = ( - quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING - and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE} - and not test.disable_rht + +class DelayedScalingRecipeAssertion(RecipeAssertionBase): + + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.MARGIN == ref_recipe.margin + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1] + assert quantize_config.AMAX_HISTORY_LEN == ref_recipe.amax_history_len + assert quantize_config.AMAX_COMPUTE_ALGO.value == ref_recipe.amax_compute_algo + for tensor_source in TensorSource: + assert ( + quantize_config.get_scaling_mode(tensor_source) + == ScalingMode.DELAYED_TENSOR_SCALING ) - self.assertEqual(quantizer.use_rht, expected_rht) - x = jnp.ones((), dtype=jnp.float32) - test_module = TestModule(assertion_func=assertion_func) - param_key, sr_key = jax.random.split(jax.random.PRNGKey(0)) - rngs = {"params": param_key, "sr_rng": sr_key} - variables = test_module.init(rngs, x) + def assert_quantizers(self, ref_recipe: DelayedScaling, quantizer, tensor_source): + assert quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING + assert quantizer.margin == ref_recipe.margin + assert quantizer.amax_compute_algo.value == ref_recipe.amax_compute_algo + assert quantizer.amax_history.shape == (ref_recipe.amax_history_len,) + assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format) - jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs) - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_autocast_delayed_scaling(self): - self._check_default_state() +class CurrentScalingRecipeAssertion(RecipeAssertionBase): - with autocast(enabled=False, recipe=DelayedScaling(), mesh_resource=MeshResource()): - self._check_default_state() + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1] + for tensor_source in TensorSource: + assert ( + quantize_config.get_scaling_mode(tensor_source) + == ScalingMode.CURRENT_TENSOR_SCALING + ) - self._check_default_state() + def assert_quantizers(self, ref_recipe: Float8CurrentScaling, quantizer, tensor_source): + assert quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format) - ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(ds) - self._check_default_state() +class MXFP8RecipeAssertion(RecipeAssertionBase): - ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) - with autocast(enabled=True, recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(ds) + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1] + for tensor_source in TensorSource: + assert quantize_config.get_scaling_mode(tensor_source) == ScalingMode.MXFP8_1D_SCALING - self._check_default_state() + def assert_quantizers(self, ref_recipe: MXFP8BlockScaling, quantizer, tensor_source): + assert quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING + assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format) - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_autocast_current_scaling(self): - self._check_default_state() - with autocast(enabled=False, recipe=Float8CurrentScaling(), mesh_resource=MeshResource()): - self._check_default_state() +class NVFP4RecipeAssertion(RecipeAssertionBase): - self._check_default_state() + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[1] + for tensor_source in TensorSource: + target_scaling_mode = ( + ScalingMode.NVFP4_2D_SCALING + if (not ref_recipe.disable_2d_quantization) and tensor_source == TensorSource.KERNEL + else ScalingMode.NVFP4_1D_SCALING + ) + assert quantize_config.get_scaling_mode(tensor_source) == target_scaling_mode + assert quantize_config.DISABLE_STOCHASTIC_ROUNDING == ref_recipe.disable_stochastic_rounding + assert quantize_config.DISABLE_RHT == ref_recipe.disable_rht + assert quantize_config.DISABLE_2D_QUANTIZATION == ref_recipe.disable_2d_quantization + + def assert_quantizers(self, ref_recipe: NVFP4BlockScaling, quantizer, tensor_source): + if tensor_source == TensorSource.KERNEL and not ref_recipe.disable_2d_quantization: + assert quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING + else: + assert quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + + if ref_recipe.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD: + assert quantizer.stochastic_rounding_rng_state is None + else: + assert quantizer.stochastic_rounding_rng_state is not None + + expected_rht = ( + quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE} + and not ref_recipe.disable_rht + ) + assert quantizer.use_rht == expected_rht - cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) - with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_current_scaling(cs) - self._check_default_state() +class TestFP8Functions(unittest.TestCase): - cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) - with autocast(enabled=True, recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_current_scaling(cs) + def _check_default_state(self): + self.assertEqual(get_global_quantize_recipe(), None) - self._check_default_state() + def _test_recipe(self, quantization_recipe: Recipe, cls: RecipeAssertionBase): + """Tests a quantization recipe by verifying its behavior in both autocast and direct application contexts.""" + assert_context_func = cls().assert_context + assert_quantizer_func = partial(cls().assert_quantizers, quantization_recipe) + self._test_recipe_autocast(quantization_recipe, assert_context_func, assert_quantizer_func) + self._test_recipe_direct(quantization_recipe, assert_quantizer_func) - @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) - def test_autocast_mxfp8_block_scaling(self): + def _test_recipe_autocast( + self, quantization_recipe, assert_context_func, assert_quantizer_func + ): + """Tests a quantization recipe within an autocast context by verifying the quantize config and quantizers in a test module.""" self._check_default_state() - - with autocast(enabled=False, recipe=MXFP8BlockScaling(), mesh_resource=MeshResource()): + with autocast(enabled=False, recipe=quantization_recipe, mesh_resource=MeshResource()): self._check_default_state() + with autocast(enabled=True, recipe=quantization_recipe, mesh_resource=MeshResource()): + quantize_config = self._get_global_quantize_config() + assert_context_func(quantization_recipe, quantize_config) + self._test_quantizer_in_model(assert_quantizer_func) + self._check_default_state() + def _test_recipe_direct(self, quantization_recipe, assert_quantizer_func): + """Tests a quantization recipe by directly passing it to a test module and verifying the quantizers.""" + self._check_default_state() + self._test_quantizer_in_model(assert_quantizer_func, direct_recipe=quantization_recipe) self._check_default_state() - bs = MXFP8BlockScaling() - with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_mxfp8_scaling(bs) + def _test_quantizer_in_model(self, assert_quantizer_func, direct_recipe=None): + """Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary. - self._check_default_state() + Args: + assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None. + direct_recipe: An optional quantization recipe to be passed directly to the test module. This is an alternative API to using autocast contexts. + """ + x = jnp.ones((), dtype=jnp.float32) + test_module = TestModule(assertion_func=assert_quantizer_func, direct_recipe=direct_recipe) + param_key, sr_key = jax.random.split(jax.random.PRNGKey(0)) + rngs = {"params": param_key, "sr_rng": sr_key} + variables = test_module.init(rngs, x) - @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) - def test_autocast_nvfp4_block_scaling(self): - self._check_default_state() + jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs) - with autocast(enabled=False, recipe=NVFP4BlockScaling(), mesh_resource=MeshResource()): - self._check_default_state() + def _get_global_quantize_config(self): + quantization_recipe = get_global_quantize_recipe() + assert quantization_recipe is not None, "No global quantization recipe set" + quantize_config = get_quantize_config_with_recipe(quantization_recipe) + assert ( + quantize_config.is_fp8_enabled() + ), "Quantization not enabled in global quantize config" + return quantize_config - self._check_default_state() + @unittest.skipIf(not is_fp8_supported, reason=reason) + def test_autocast_delayed_scaling(self): + self._test_recipe( + quantization_recipe=DelayedScaling(), + cls=DelayedScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=DelayedScaling( + margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1 + ), + cls=DelayedScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=DelayedScaling( + margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1 + ), + cls=DelayedScalingRecipeAssertion, + ) - bs = NVFP4BlockScaling() - with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_nvfp4_scaling(bs) - self._compare_nvfp4_scaling_quantizers(bs) + @unittest.skipIf(not is_fp8_supported, reason=reason) + def test_autocast_current_scaling(self): + self._test_recipe( + quantization_recipe=Float8CurrentScaling(), + cls=CurrentScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3), + cls=CurrentScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID), + cls=CurrentScalingRecipeAssertion, + ) - bs = NVFP4BlockScaling( - disable_stochastic_rounding=True, - disable_rht=True, - disable_2d_quantization=True, + @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) + def test_autocast_mxfp8_block_scaling(self): + self._test_recipe( + quantization_recipe=MXFP8BlockScaling(), + cls=MXFP8RecipeAssertion, ) - with autocast(enabled=True, recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_nvfp4_scaling(bs) - self._compare_nvfp4_scaling_quantizers(bs) - self._check_default_state() + @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) + def test_autocast_nvfp4_block_scaling(self): + self._test_recipe( + quantization_recipe=NVFP4BlockScaling(), + cls=NVFP4RecipeAssertion, + ) + self._test_recipe( + quantization_recipe=NVFP4BlockScaling( + disable_stochastic_rounding=True, + disable_rht=True, + disable_2d_quantization=True, + ), + cls=NVFP4RecipeAssertion, + ) class TestJaxprAndHlo: diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 96b73909e..556b58719 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -221,7 +221,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F """ Helper function to manage primitive states by name without modifying environment variables. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. - This helper is used in the get_quantize_config().initialize() methods. + This helper is used in the get_quantize_config_with_recipe().initialize() methods. Args: enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9ffec2c6a..c00b816f2 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -38,12 +38,13 @@ ScalingMode, Quantizer, GroupedQuantizer, - get_quantize_config, QuantizerSet, QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, + get_quantize_config_with_recipe, + get_global_quantize_recipe, ) from .misc import get_padded_spec, is_all_reduce_in_float32 from ..sharding import ( @@ -1246,7 +1247,7 @@ def _te_gemm( fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, - use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, + use_split_accumulator: bool = None, transpose_batch_sequence: bool = False, collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: @@ -1258,6 +1259,13 @@ def _te_gemm( DeprecationWarning, ) + if use_split_accumulator is None: + # TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also + # use context of the GEMM type so we can decide between fprop, dgrad, and wgrad + use_split_accumulator = get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).FP8_2X_ACC_FPROP + # Prepare non-quantized GEMM operands lhs_data = lhs rhs_data = rhs @@ -1720,10 +1728,15 @@ def _jax_gemm_impl(lhs, rhs): assert ( rhs.scaling_mode == lhs.scaling_mode ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" + + # TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also + # use context of the GEMM type so we can decide between fprop, dgrad, and wgrad + use_split_accumulator = get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).FP8_2X_ACC_FPROP + precision = ( - jax.lax.Precision.HIGHEST - if get_quantize_config().FP8_2X_ACC_FPROP - else jax.lax.Precision.DEFAULT + jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT ) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index a0e1a6406..d16dab6d6 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -820,7 +820,7 @@ def _quantize_dbias_impl( amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, ) - scale = compute_scale_from_amax(amax, quantizer.q_dtype) + scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # Make sure to reset amax to zeros for DelayedScaling @@ -1227,7 +1227,7 @@ def grouped_quantize( ) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): - tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) + tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) scale = scale.at[i].set(tmp_scale[0]) is_tensor_scaling = quantizer.scaling_mode in ( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index c497775e0..613455b6c 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -27,7 +27,6 @@ with_sharding_constraint_by_logical_axes, is_fp8_gemm_with_all_layouts_supported, TensorUsage, - get_quantize_config, ) @@ -95,7 +94,7 @@ def dense( if transpose_batch_sequence: warnings.warn("transpose_batch_sequence is not well tested, use with caution!") - if not get_quantize_config().is_fp8_enabled(): + if quantizer_set == noop_quantizer_set: input_dtype = x.dtype kernel = kernel.astype(input_dtype) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 934af3d18..b5f159022 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -33,10 +33,11 @@ ) from ..quantize import ( QuantizerFactory, - get_quantize_config, + get_global_quantize_recipe, QuantizeMetaSet, TensorSource, get_quantize_config_with_recipe, + noop_quantizer_set, ) PRNGKey = Any @@ -355,17 +356,17 @@ def generate_quantizer_set( Generate a set of FP8 meta for a GEMM. """ + if fp8_recipe is None: + fp8_recipe = get_global_quantize_recipe() + + quantize_config = get_quantize_config_with_recipe(fp8_recipe) + collection_name = ( variable_collection if variable_collection is not None - else get_quantize_config().COLLECTION_NAME + else quantize_config.COLLECTION_NAME ) - if fp8_recipe is None: - quantize_config = get_quantize_config() - else: - quantize_config = get_quantize_config_with_recipe(fp8_recipe) - x_meta = quantize_config.get_quantize_flax_meta( self, collection_name, postfix, TensorSource.X, "x" ) @@ -492,7 +493,11 @@ def __call__(self, inputs: Array) -> Array: self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + quantizer_set = self.generate_quantizer_set( + quantization_checkpoint_name=self.quantization_checkpoint_name + ) + + if quantizer_set == noop_quantizer_set: kernel = kernel.astype(input_dtype) if self.use_bias: @@ -505,9 +510,6 @@ def __call__(self, inputs: Array) -> Array: else: bias = None - quantizer_set = self.generate_quantizer_set( - quantization_checkpoint_name=self.quantization_checkpoint_name - ) contract_ind = tuple(range(0, len(axis))) y = dense( inputs, @@ -712,7 +714,7 @@ def __call__(self, inputs: Array) -> Array: ) fuse_layernorm = ( - get_quantize_config().is_fp8_enabled() + quantizer_set != noop_quantizer_set and not self.return_layernorm_output and self.enable_layernorm ) @@ -763,7 +765,7 @@ def __call__(self, inputs: Array) -> Array: kernel_shape, self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + if quantizer_set == noop_quantizer_set: kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -1042,7 +1044,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: # TODO(Phuong): use fuse_layernorm for high-precision # when NoOpQuantizer and Tensor are implemented fuse_layernorm = ( - get_quantize_config().is_fp8_enabled() + ffn1_quantizer_set != noop_quantizer_set and not self.return_layernorm_output and self.enable_layernorm ) @@ -1128,7 +1130,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + if ffn1_quantizer_set == noop_quantizer_set: kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] @@ -1140,7 +1142,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_2_shape, self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + if ffn2_quantizer_set == noop_quantizer_set: kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index b9482b7bd..14726553f 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -23,7 +23,6 @@ noop_quantizer_set, with_sharding_constraint_by_logical_axes, TensorUsage, - get_quantize_config, ) @@ -73,7 +72,7 @@ def layernorm_dense( - Quantization is applied to both the normalized input and kernel """ - if not get_quantize_config().is_fp8_enabled(): + if quantizer_set == noop_quantizer_set: input_dtype = x.dtype kernel = kernel.astype(input_dtype) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 2fd0f07d6..47fed6c3a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -28,7 +28,6 @@ QuantizerSet, noop_quantizer_set, TensorUsage, - get_quantize_config, ) @@ -114,7 +113,7 @@ def layernorm_mlp( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" - if not get_quantize_config().is_fp8_enabled(): + if quantizer_sets == (noop_quantizer_set, noop_quantizer_set): input_dtype = x.dtype kernel_1 = kernel_1.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index d5093e70e..6358edf46 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -46,7 +46,7 @@ from .device_utils import get_device_compute_capability __all__ = [ - "get_quantize_config", + "get_global_quantize_recipe", "get_quantize_config_with_recipe", "autocast", "fp8_autocast", @@ -475,7 +475,12 @@ def get_quantize_flax_meta( (self.AMAX_HISTORY_LEN,), jnp.float32, ).value - return QuantizeMeta(scale=scale, amax_history=amax_history) + return QuantizeMeta( + margin=self.MARGIN, + amax_compute_algo=self.AMAX_COMPUTE_ALGO, + scale=scale, + amax_history=amax_history, + ) class CurrentScalingQuantizeConfig(BaseQuantizeConfig): @@ -669,14 +674,6 @@ def get_quantize_flax_meta( ) -_QUANTIZE_CONFIG = NoOpQuantizeConfig() - - -def get_quantize_config(): - """Global instance of BaseQuantizeConfig set by autocast context.""" - return _QUANTIZE_CONFIG - - def get_quantize_config_class( fp8_recipe: Recipe, ) -> Type[BaseQuantizeConfig]: @@ -687,6 +684,8 @@ def get_quantize_config_class( Returns: The quantization config class corresponding to the given recipe. """ + if fp8_recipe is None: + return NoOpQuantizeConfig if isinstance(fp8_recipe, DelayedScaling): return DelayedScalingQuantizeConfig if isinstance(fp8_recipe, MXFP8BlockScaling): @@ -701,10 +700,23 @@ def get_quantize_config_class( def get_quantize_config_with_recipe(fp8_recipe: Recipe): """Get the quantization configuration object based on the FP8 recipe.""" config = get_quantize_config_class(fp8_recipe)() - config.initialize_from_recipe(fp8_recipe) + if fp8_recipe is not None: + config.initialize_from_recipe(fp8_recipe) return config +_GLOBAL_RECIPE: Optional[Recipe] = None + + +def get_global_quantize_recipe() -> Optional[Recipe]: + """Get the global quantization recipe if set. + + Returns: + The global quantization recipe or None if not set. + """ + return _GLOBAL_RECIPE + + @contextmanager def autocast( enabled: bool = False, @@ -751,22 +763,21 @@ def autocast( if recipe is None: recipe = DelayedScaling() - global _QUANTIZE_CONFIG + global _GLOBAL_RECIPE - old_quantize_config = _QUANTIZE_CONFIG + old_global_recipe = _GLOBAL_RECIPE - _QUANTIZE_CONFIG = NoOpQuantizeConfig() + _GLOBAL_RECIPE = None try: with global_shard_guard(mesh_resource): if enabled: - _QUANTIZE_CONFIG = get_quantize_config_class(recipe)() - is_supported, reason = _QUANTIZE_CONFIG.is_supported() + _GLOBAL_RECIPE = recipe + is_supported, reason = get_quantize_config_class(_GLOBAL_RECIPE)().is_supported() assert is_supported, reason - _QUANTIZE_CONFIG.initialize_from_recipe(recipe) yield finally: - _QUANTIZE_CONFIG = old_quantize_config + _GLOBAL_RECIPE = old_global_recipe @contextmanager diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index adff31748..4edc18779 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -28,7 +28,7 @@ NoScaleTensor, ) from .helper import ( - get_quantize_config, + get_global_quantize_recipe, get_quantize_config_with_recipe, AmaxComputeAlgo, TensorSource, @@ -50,7 +50,7 @@ def compute_scale_from_amax( - amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None + amax: jnp.ndarray, q_dtype: jnp.dtype, margin: float, scale: Optional[jnp.ndarray] = None ) -> jnp.ndarray: """Compute scale from amax value. @@ -64,7 +64,7 @@ def compute_scale_from_amax( fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) if scale is None: scale = jnp.ones((1,)) - sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) + sf = (fp8_max / amax) / (2**margin) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" @@ -223,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer): Attributes: scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING q_layout: Quantization axis (default: ROWWISE_COLWISE) + data_layout: Data layout string (default: "NT") """ scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING @@ -254,8 +255,7 @@ def _quantize_func( compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) - fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) - scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) + scale = compute_scale_from_amax(amax, self.q_dtype, margin=0.0) scaled_x = x.data.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) @@ -327,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): Attributes: scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING q_layout: Quantization axis (default: ROWWISE_COLWISE) + data_layout: Data layout string (default: "NT") + margin: Margin value for scale computation + amax_compute_algo: Algorithm for computing amax scale: Current scaling factor amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING - q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE + margin: float = 0.0 + amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) - amax_history: jnp.ndarray = field( - default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32) - ) + amax_history: jnp.ndarray = field(default_factory=lambda: jnp.zeros((1024,), jnp.float32)) + + def __post_init__(self): + assert self.margin is not None, "margin must be specified" + assert self.amax_compute_algo is not None, "amax_compute_algo must be specified" + assert self.amax_history is not None, "amax_history must be specified" def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -352,6 +358,8 @@ def tree_flatten(self): self.q_layout, self.data_layout, self.checkpoint_name, + self.margin, + self.amax_compute_algo, ) return (children, aux_data) @@ -407,12 +415,14 @@ def _update_amax_history(amax_history, new_amax): Returns: Updated AMAX history """ - amax_history = amax_history.at[0].set(new_amax[0]) + amax_history = amax_history.at[0].set(new_amax.reshape((1,))[0]) return amax_history @staticmethod - @partial(jax.jit, static_argnums=(2,)) - def _compute_scale(amax_history, scale, q_dtype): + @partial(jax.jit, static_argnums=(2, 3, 4)) + def _compute_scale( + amax_history, scale, q_dtype, amax_compute_algo: AmaxComputeAlgo, margin: float + ): """Compute new scale based on AMAX history. Args: @@ -424,12 +434,12 @@ def _compute_scale(amax_history, scale, q_dtype): Updated scale value """ # 2. Calculate the current scale - if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + if amax_compute_algo is AmaxComputeAlgo.MAX: amax = jnp.max(amax_history, axis=-1, keepdims=True) else: amax = amax_history[0:1] - return compute_scale_from_amax(amax, q_dtype, scale=scale) + return compute_scale_from_amax(amax, q_dtype, margin=margin, scale=scale) @staticmethod @jax.jit @@ -453,7 +463,9 @@ def update(self, new_amax: jnp.ndarray): new_amax: New maximum absolute value to add to history """ amax_history = self._update_amax_history(self.amax_history, new_amax) - self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype) + self.scale = self._compute_scale( + amax_history, self.scale, self.q_dtype, self.amax_compute_algo, self.margin + ) self.amax_history = self._roll_and_reset_amax_history(amax_history) @@ -1124,6 +1136,7 @@ def _create_set( bwd_dtype, is_2x2x, n_groups, + is_inference_mode=False, checkpoint_name: Optional[str] = None, **kwargs, ) -> QuantizerSet: @@ -1137,6 +1150,7 @@ def _create_set( bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization n_groups + is_inference_mode: Whether to create quantizers for inference mode. This option is not fully supported yet checkpoint_name: Optional name for checkpointing quantizations **kwargs: Additional arguments for quantizer initialization @@ -1149,7 +1163,7 @@ def _create_set( q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE if kernel_scaling_mode.is_1d_block_scaling(): q_layout_kernel = QuantizeLayout.COLWISE - if get_quantize_config().INFERENCE_MODE: + if is_inference_mode: q_layout_dgrad = None if "quantize_meta_set" in kwargs: @@ -1206,10 +1220,10 @@ def create_set( Args: n_quantizer_sets: Number of quantizer sets to create - scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode - fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE - bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE - is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X + scaling_mode: Scaling mode to use, default is get the scaling mode from the specified or global recipe + fwd_dtype: Data type for forward pass, default is get the fwd dtype from the specified or global recipe + bwd_dtype: Data type for backward pass, default is get the bwd dtype from the specified or global recipe + is_2x2x: Whether to use 2x2x quantization, default is determined based on the specified or global recipe n_groups: checkpoint_name: Optional name for checkpointing quantizations fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. @@ -1226,25 +1240,46 @@ def create_set( " scaling mode differs between x, kernel, and grad in the quantizer set." ) + # TODO(jberchtold): Currently this is a limitation because we only support automatically populating quantizer fields based on a given recipe when using Flax. In the generic quantizer logic, we cannot assume Flax is being used, so we require the user to provide the quantize_meta_set created by quantize_config.get_quantize_flax_meta() or the same data created by themselves if they are passing a recipe here directly. + assert ( + fp8_recipe is None or "quantize_meta_set" in kwargs + ), "When fp8_recipe is specified, quantize_meta_set must be provided in kwargs." + + if fp8_recipe is None: + fp8_recipe = get_global_quantize_recipe() + if fp8_recipe is not None: + assert scaling_mode is None, ( + "scaling_mode should not be specified when fp8_recipe is provided either directly" + " or through an autocast context." + ) + assert fwd_dtype is None, ( + "fwd_dtype should not be specified when fp8_recipe is provided either directly or" + " through an autocast context." + ) + assert bwd_dtype is None, ( + "bwd_dtype should not be specified when fp8_recipe is provided either directly or" + " through an autocast context." + ) quantize_config = get_quantize_config_with_recipe(fp8_recipe) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) fwd_dtype = quantize_config.FWD_DTYPE bwd_dtype = quantize_config.BWD_DTYPE + is_inference_mode = quantize_config.INFERENCE_MODE else: if scaling_mode is not None: x_scaling_mode = scaling_mode kernel_scaling_mode = scaling_mode grad_scaling_mode = scaling_mode else: - x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) - kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) - grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + # TODO(jberchtold): make a way to explicitly pass a no scaling recipe here if we need other quantization config attributes in the future since NoOpQuantizeConfig already exists, we just can't use it here with direct recipe passing because we cannot differentiate between fp8_recipe=None meaning no recipe specified vs explicitly no quantization desired. + x_scaling_mode = ScalingMode.NO_SCALING + kernel_scaling_mode = ScalingMode.NO_SCALING + grad_scaling_mode = ScalingMode.NO_SCALING + is_inference_mode = False - fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE - bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: # TODO(Jeremy): check x, kernel, grad separately for 2x if x_scaling_mode.is_1d_block_scaling(): @@ -1253,7 +1288,6 @@ def create_set( is_2x2x = not is_fp8_gemm_with_all_layouts_supported() else: # NO_SCALING ignores is_2x2x for now is_2x2x = False - is_inference_mode = get_quantize_config().INFERENCE_MODE assert not is_inference_mode, "Inference mode is not supported yet!" q_set = [] @@ -1267,6 +1301,7 @@ def create_set( bwd_dtype=bwd_dtype, is_2x2x=is_2x2x, n_groups=n_groups, + is_inference_mode=is_inference_mode, checkpoint_name=checkpoint_name, **kwargs, ) From c525760538b5cb1b77f3d93ab2c98d75b9453f43 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 14 Nov 2025 20:28:39 +0100 Subject: [PATCH 127/141] [PyTorch] Activation offloading refactor (#1762) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * init Signed-off-by: Pawel Gadzinski * offloading Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * all types Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * typo Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * init Signed-off-by: Pawel Gadzinski * api change Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * refactor Signed-off-by: Pawel Gadzinski * tests Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * example Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * cpu offload + debug warning Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * change empty_like implementation to use make_like Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * main_grad fix Signed-off-by: Pawel Gadzinski * manual synchornization Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * old path Signed-off-by: Pawel Gadzinski * remove example Signed-off-by: Pawel Gadzinski * api changes Signed-off-by: Pawel Gadzinski * reverted grouped linear Signed-off-by: Pawel Gadzinski * make odl code path work for modules Signed-off-by: Pawel Gadzinski * attention old code path Signed-off-by: Pawel Gadzinski * legacy tests Signed-off-by: Pawel Gadzinski * legacy tests Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * updated code path Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/tensor/quantized_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * nvfp4 support Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update tests/pytorch/test_cpu_offloading.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> * small fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * docs change Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: root Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 3 +- tests/pytorch/test_cpu_offloading.py | 884 ++++++++--- tests/pytorch/test_cpu_offloading_v1.py | 215 +++ .../dot_product_attention/backends.py | 69 +- .../dot_product_attention.py | 8 - .../pytorch/attention/multi_head_attention.py | 5 +- transformer_engine/pytorch/cpu_offload.py | 1354 +++++++++-------- transformer_engine/pytorch/cpu_offload_v1.py | 743 +++++++++ .../pytorch/module/grouped_linear.py | 8 +- .../pytorch/module/layernorm_linear.py | 19 +- .../pytorch/module/layernorm_mlp.py | 21 +- transformer_engine/pytorch/module/linear.py | 11 +- .../pytorch/optimizers/fused_adam.py | 4 +- .../pytorch/quantized_tensor.py | 96 +- .../pytorch/tensor/float8_blockwise_tensor.py | 10 +- .../pytorch/tensor/float8_tensor.py | 36 +- .../pytorch/tensor/mxfp8_tensor.py | 36 +- .../pytorch/tensor/nvfp4_tensor.py | 45 +- .../float8_blockwise_tensor_storage.py | 7 + .../tensor/storage/float8_tensor_storage.py | 9 + .../tensor/storage/mxfp8_tensor_storage.py | 7 + 21 files changed, 2714 insertions(+), 876 deletions(-) create mode 100644 tests/pytorch/test_cpu_offloading_v1.py create mode 100644 transformer_engine/pytorch/cpu_offload_v1.py diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index b23ce3b6c..e1ce68009 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -42,7 +42,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 64da83a21..c5b4b48b6 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,27 +2,41 @@ # # See LICENSE for license information. +import random import contextlib -import gc -import os -from typing import Iterable, Optional - import pytest +import os import torch - +from typing import Optional, List +from transformer_engine.pytorch.cpu_offload import ( + get_cpu_offload_context, + OffloadableLayerState, + DefaultOffloadSynchronizer, + start_offload, + mark_not_offload, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends -from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported -from utils import ModelConfig, get_available_attention_backends +from utils import ModelConfig +import transformer_engine_torch as tex # Check supported quantization schemes -fp8_available = te.is_fp8_available() -mxfp8_available = te.is_mxfp8_available() +fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() -quantization_recipes: Optional[recipe.Recipe] = [None] +quantization_recipes: List[Optional[recipe.Recipe]] = [None] if fp8_available: quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) +if fp8_block_scaling_available: + quantization_recipes.append(recipe.Float8BlockScaling()) +if mxfp8_available: + quantization_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: + quantization_recipes.append(recipe.NVFP4BlockScaling()) + model_config = { "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), @@ -32,181 +46,709 @@ NUM_LAYERS = model_config["small"].num_layers EPSILON = model_config["small"].eps -# Flash attention saves some internal tensor for the backward pass -# that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN") == "0" +# Disable garbage collection to tests if there are reference cycles. +# We do not want them, because they can result in CUDA out of memory errors. +import gc -# Offloading is supported for attention only for fused and flash attention backends, -# so the use of bfloat16 is required. -# -# For the TransformerLayer, activation offloading with dropout is not supported, -# so we set hidden_dropout to 0.0. -model_types = { - "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), - "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), - "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), - "multihead_attention": lambda: te.MultiheadAttention( - SIZE, NUM_HEADS, params_dtype=torch.bfloat16 - ), - "transformer_layer": lambda: te.TransformerLayer( - SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 - ), - "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), - "layernorm_mlp_ops": lambda: te.ops.Sequential( - te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), - te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), - te.ops.GELU(), - te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), - ), -} +gc.disable() + + +class Utils: + tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16) + _B = 64 + _S = 256 + _H = 4 + _D = 256 + + @staticmethod + def long_job(stream: Optional[torch.cuda.Stream] = None): + NUM_ITERS = 6000 + if stream is None: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + for i in range(NUM_ITERS): + Utils.tensor1.normal_() + + @staticmethod + def measure_time(func): + import time + + torch.cuda.synchronize() + start = time.time() + func() + torch.cuda.synchronize() + end = time.time() + return (end - start) * 1000 + + @staticmethod + def get_cuda_memory_mb(): + return torch.cuda.memory_allocated() / (1024**2) + + @staticmethod + def get_max_cuda_memory_mb(): + return torch.cuda.max_memory_allocated() / (1024**2) + + @staticmethod + def get_cpu_memory_mb() -> float: + import psutil, os + + return psutil.Process(os.getpid()).memory_info().rss / (1024**2) + + @staticmethod + def get_layer_names(): + return [ + "linear", + "layernorm_linear", + "layernorm_mlp", + "grouped_linear", + "multihead_attention", + "transformer_layer", + "linear_op", + "layernorm_mlp_ops", + ] + + @staticmethod + def create_layer(layer_type: str): + if layer_type == "linear": + return te.Linear(Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "layernorm_linear": + return te.LayerNormLinear(Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "layernorm_mlp": + return te.LayerNormMLP(Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "multihead_attention": + return te.MultiheadAttention( + Utils._D, Utils._H, attention_dropout=0.0, params_dtype=torch.bfloat16 + ) + elif layer_type == "grouped_linear": + return te.GroupedLinear(Utils._H, Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "transformer_layer": + return te.TransformerLayer( + Utils._D, + Utils._D, + Utils._H, + attention_dropout=0.0, + hidden_dropout=0.0, + params_dtype=torch.bfloat16, + ) + elif layer_type == "linear_op": + return te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16) + elif layer_type == "layernorm_mlp_ops": + return te.ops.Sequential( + te.ops.LayerNorm(Utils._D, dtype=torch.bfloat16), + te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16), + ) + else: + raise ValueError(f"Unknown layer type: {layer_type}") + + @staticmethod + def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) -> torch.Tensor: + shape = (Utils._B, Utils._S, Utils._D) + tensor = torch.randn(shape, device="cuda", dtype=torch.bfloat16) + if recipe is None: + tensor = tensor.requires_grad_() if requires_grad else tensor + return tensor + elif recipe.delayed(): + quantizer = te.tensor.float8_tensor.Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1.0], device="cuda"), + amax=torch.tensor([1.0], device="cuda"), + ) + return quantizer(tensor) + elif recipe.float8_current_scaling(): + quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, device="cuda" + ) + return quantizer(tensor) + elif recipe.float8_block_scaling(): + quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True + ) + return quantizer(tensor) + elif recipe.mxfp8(): + quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + return quantizer(tensor) + elif recipe.nvfp4(): + quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer() + return quantizer(tensor) + + @staticmethod + def create_recipe_ctx(recipe: Optional[recipe.Recipe]): + if recipe is None: + return lambda: contextlib.nullcontext() + else: + return lambda: te.fp8_autocast(fp8_recipe=recipe) + + @staticmethod + def get_tensor_size_mb(tensor): + if tensor is None: + return 0 + if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): + return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + else: + return tensor.numel() * tensor.element_size() / (1024**2) + + @staticmethod + def memory_leak_check(): + # Should be called before each test. + # Only cublas workspaces and some global tensors are allowed to be allocated. + # All other allocations should be released. + # This is a simple check to catch memory leaks. + if Utils.get_cuda_memory_mb() > 1000: + memory_num = Utils.get_cuda_memory_mb() + import gc + + gc.collect() # We want next test to be run with clean state. + gc.disable() + raise RuntimeError(f"Memory leak: {memory_num} MB") + + +class TestsOffloadableLayerState: + @pytest.mark.parametrize("random_num_tensors", [True, False]) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_general(self, random_num_tensors, recipe): + """ + Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, + for each layer offload random number of random tensors. + Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors. + """ + Utils.memory_leak_check() + NUM_ITERATIONS = 10 + + stream = torch.cuda.Stream() + + offload_layer_state = OffloadableLayerState( + offload_stream=stream, + ) + for _ in range(NUM_ITERATIONS): + original_tensors = [] + tensors_ids = [] + NUM_TENSORS = random.choice([1, 20]) if random_num_tensors else 1 + for _ in range(NUM_TENSORS): + tensor = Utils.create_tensor(recipe) + original_tensors.append(tensor) + tensor_id = offload_layer_state.push_tensor(tensor) + assert tensor.device.type == "cuda" + tensors_ids.append(tensor_id) + + offload_layer_state.start_offload() + offload_layer_state.release_activation_forward_gpu_memory() + offload_layer_state.start_reload() + + for j in range(len(tensors_ids)): + tensor_gpu = offload_layer_state.pop_tensor(tensors_ids[j]) + assert tensor_gpu.device.type == "cuda" + assert tensor_gpu.shape == original_tensors[j].shape + assert tensor_gpu.dtype == original_tensors[j].dtype + torch.testing.assert_close(tensor_gpu, original_tensors[j]) + offload_layer_state.release_all_memory() + torch.cuda.synchronize() + + def test_offload_base_tensor(self): + Utils.memory_leak_check() + stream = torch.cuda.Stream() + offload_layer_state = OffloadableLayerState( + offload_stream=stream, + ) + init_cuda_memory = Utils.get_cuda_memory_mb() + x = Utils.create_tensor(None) + x_size = Utils.get_tensor_size_mb(x) + x_1 = x[::2] + x_2 = x[1::2] + + start_offload(x_1, offload_base_tensor=True) + start_offload(x_2, offload_base_tensor=True) + x1_id = offload_layer_state.push_tensor(x_1) + x2_id = offload_layer_state.push_tensor(x_2) + del x_1, x_2 + offload_layer_state.start_offload() + offload_layer_state.release_activation_forward_gpu_memory() + + assert offload_layer_state.get_offloaded_total_size_mb() == pytest.approx(x_size, 0.1) + + offload_layer_state.start_reload() + x_1 = offload_layer_state.pop_tensor(x1_id) + x_2 = offload_layer_state.pop_tensor(x2_id) + assert x_1.device.type == "cuda" + assert x_2.device.type == "cuda" + + assert torch.allclose(x_1, x[::2]) + assert torch.allclose(x_2, x[1::2]) + del x + + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1) + + +class TestsDefaultOffloadSynchronizer: + @pytest.mark.parametrize("random_num_tensors", [True, False]) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_general(self, random_num_tensors, recipe): + """ + Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, + for each layer offload random number of random tensors. + Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors. + """ + Utils.memory_leak_check() + NUM_LAYERS = 10 + NUM_ITERATIONS = 10 + + offload_synchronizer = DefaultOffloadSynchronizer( + num_layers=NUM_LAYERS, + num_offloaded_layers=NUM_LAYERS - 1, + ) + + for _ in range(NUM_ITERATIONS): + original_tensors = [] + tensors_ids = [] + layer_ids = [] + + for i in range(NUM_LAYERS): + NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1 + layer_tensors = [] + layer_tensors_ids = [] + layer_id = offload_synchronizer.fwd_step() + for _ in range(NUM_LAYER_TENSORS): + tensor = Utils.create_tensor(recipe) + layer_tensors.append(tensor) + tensor_id = offload_synchronizer.push_tensor(tensor) + assert tensor.device.type == "cuda" + layer_tensors_ids.append(tensor_id) + layer_ids.append(layer_id) + tensors_ids.append(layer_tensors_ids) + original_tensors.append(layer_tensors) + for i in range(NUM_LAYERS - 1, -1, -1): + offload_synchronizer.bwd_step(layer_ids[i]) + for j in range(len(tensors_ids[i])): + tensor_gpu = offload_synchronizer.pop_tensor(tensors_ids[i][j]) + assert tensor_gpu.device.type == "cuda" + assert tensor_gpu.shape == original_tensors[i][j].shape + assert tensor_gpu.dtype == original_tensors[i][j].dtype + torch.testing.assert_close(tensor_gpu, original_tensors[i][j]) + offload_synchronizer.finish_part_of_bwd() + torch.cuda.synchronize() + + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_memory(self, recipe): + torch.cuda.synchronize() + Utils.memory_leak_check() + NUM_LAYERS = 10 + + torch.cuda.reset_peak_memory_stats() + + offload_synchronizer = DefaultOffloadSynchronizer( + num_layers=NUM_LAYERS, + num_offloaded_layers=NUM_LAYERS - 1, + ) -def _make_input() -> torch.Tensor: - """Generate random input tensor.""" - return torch.randn( - (128, SIZE, SIZE), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - - -def _warmup_model( - modules: Iterable[torch.nn.Module], - quantization_recipe: Optional[recipe.Recipe], -) -> None: - """Perform forward and backward pass""" - tensor = _make_input() - for module in modules: - with te.autocast( - enabled=quantization_recipe is not None, - recipe=quantization_recipe, + init_cuda_memory = Utils.get_cuda_memory_mb() + + tensor_ids = [] + + torch.cuda.synchronize() + for _ in range(NUM_LAYERS): + offload_synchronizer.fwd_step() + tensor = Utils.create_tensor(recipe) + tensor_size = Utils.get_tensor_size_mb(tensor) + tensor_id = offload_synchronizer.push_tensor(tensor) + assert tensor.device.type == "cuda" + tensor_ids.append(tensor_id) + del tensor, tensor_id + torch.cuda.synchronize() + + if recipe is None: + assert Utils.get_max_cuda_memory_mb() == pytest.approx( + init_cuda_memory + tensor_size, 0.1 + ) + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1) + + for i in range(NUM_LAYERS - 1, -1, -1): + offload_synchronizer.bwd_step(i) + tensor_gpu = offload_synchronizer.pop_tensor(tensor_ids[i]) + assert tensor_gpu.device.type == "cuda" + del tensor_gpu, tensor_ids[i] + offload_synchronizer.finish_part_of_bwd() + + del tensor_ids + torch.cuda.synchronize() + + if recipe is None: + assert Utils.get_max_cuda_memory_mb() == pytest.approx( + init_cuda_memory + tensor_size, 0.1 + ) + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_multiple_tensor_offload(self, recipe): + Utils.memory_leak_check() + init_cpu_memory = Utils.get_cpu_memory_mb() + init_cuda_memory = Utils.get_cuda_memory_mb() + offload_synchronizer = DefaultOffloadSynchronizer( + num_layers=2, + num_offloaded_layers=1, + ) + x1 = Utils.create_tensor(recipe) + x_size = Utils.get_tensor_size_mb(x1) + offload_synchronizer.fwd_step() + offload_synchronizer.push_tensor(x1) + offload_synchronizer.push_tensor(x1) + offload_synchronizer.push_tensor(x1) + offload_synchronizer.fwd_step() + # Only one copy of tensor on cpu is allocated. + assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1) + del x1 + offload_synchronizer.bwd_step(1) + offload_synchronizer.bwd_step(0) + offload_synchronizer.finish_part_of_bwd() + + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + + +class TestTELayers: + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_sanity(self, layer_type, recipe): + Utils.memory_leak_check() + + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() ): - tensor = module(tensor) - tensor.sum().backward() - - -def _estimate_cached_weight_size( - model_name: str, - modules: Iterable[torch.nn.Module], - quantization_recipe: Optional[recipe.Recipe], -) -> float: - """Calculate the memory (in MiB) needed for weight caching.""" - - # The weight params are cached directly for unquantized compute - if quantization_recipe is None: - return 0 - - # Count number of weight param elements - param_elements = 0 - for module in modules: - for param in module.parameters(): - if param.dim() == 2: - param_elements += param.numel() - - # FP8 tensor-scaling caches one byte per element - if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): - if not is_non_tn_fp8_gemm_supported() and model_name not in ( - "linear_op", - "layernorm_mlp_ops", + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + + recipe_ctx = Utils.create_recipe_ctx(recipe) + init_cuda_memory = Utils.get_cuda_memory_mb() + OFFLOAD_LAYERS = 6 + NUM_LAYERS = 10 + offload_ctx, sync_function = get_cpu_offload_context( + enabled=True, + num_layers=OFFLOAD_LAYERS, + model_layers=NUM_LAYERS, + ) + layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)] + inp = Utils.create_tensor(None) + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + out = inp + for i in range(NUM_LAYERS): + with offload_ctx, recipe_ctx(): + # Ops-based layers don't support is_first_microbatch parameter + if layer_type in ["linear_op", "layernorm_mlp_ops"]: + out = layers[i](out, **m_splits) + else: + out = layers[i](out, is_first_microbatch=False, **m_splits) + out = sync_function(out) + out.sum().backward() + torch.cuda.synchronize() + del out, inp, layers + + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_memory(self, layer_type, recipe): + Utils.memory_leak_check() + + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() + ): + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + + offload_ctx, sync_function = get_cpu_offload_context( + enabled=True, + num_layers=1, + model_layers=2, + offload_activations=True, + offload_weights=False, + ) + recipe_ctx = Utils.create_recipe_ctx(recipe) + layer = Utils.create_layer(layer_type) + inp = Utils.create_tensor(None) + + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + + # Ops-based layers don't support is_first_microbatch parameter + is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"] + + with recipe_ctx(): + if is_ops_layer: + out = layer(inp, **m_splits) + else: + out = layer(inp, is_first_microbatch=True, **m_splits) + out.sum().backward() + + del inp + init_cuda_memory = Utils.get_cuda_memory_mb() + + # run layer without offload + inp = Utils.create_tensor(None) + with recipe_ctx(): + if is_ops_layer: + out = layer(inp, **m_splits) + else: + out = layer(inp, is_first_microbatch=False, **m_splits) + with recipe_ctx(): + out = out + 1 + del inp + cuda_memory_no_offload = Utils.get_cuda_memory_mb() + + out.sum().backward() + # run layer with offload + inp = Utils.create_tensor(None) + with offload_ctx, recipe_ctx(): + if is_ops_layer: + out = layer(inp, **m_splits) + else: + out = layer(inp, is_first_microbatch=False, **m_splits) + out = sync_function(out) + with offload_ctx, recipe_ctx(): + out = out + 1 + out = sync_function(out) + del inp + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb() + + # This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer. + # It helps catch cases where an offloaded tensor still has a live pointer, which would + # cause an unnecessary copy to the CPU and prevent GPU memory from being released. + assert Utils.get_cuda_memory_mb() + offloaded_memory_cpu == pytest.approx( + cuda_memory_no_offload, 0.1 + ) + out.sum().backward() + + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_manual_synchronization(self, recipe, layer_type): + Utils.memory_leak_check() + + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() ): - # Modules do not deallocate FP8 transpose for weights - return 2 * param_elements / 1024**2 - return param_elements / 1024**2 + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + + offload_ctx, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, + model_layers=6, + offload_activations=True, + manual_synchronization=True, + ) + layer_1 = Utils.create_layer(layer_type) + layer_2 = Utils.create_layer(layer_type) + inp1 = Utils.create_tensor(None) + inp2 = Utils.create_tensor(None) - # MXFP8 caches one data byte per element and one scale byte per 32 - # elements - if quantization_recipe.mxfp8(): - if model_name not in ("linear_op", "layernorm_mlp_ops"): - # Modules do not deallocate column-wise MXFP8 data for weights - return 2 * param_elements * (1 + 1 / 32) / 1024**2 - return param_elements * (1 + 1 / 32) / 1024**2 + recipe_ctx = Utils.create_recipe_ctx(recipe) - raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + + init_cuda_memory = Utils.get_cuda_memory_mb() + + # 1 fwd + with offload_ctx, recipe_ctx(): + out_1 = layer_1(inp1, **m_splits) + out_1 = sync_function(out_1) + + with offload_ctx, recipe_ctx(): + out_2 = layer_2(inp2, **m_splits) + out_2 = sync_function(out_2) + + mark_not_offload(out_1, out_2) + + del inp1, inp2 + + memory_before_offload = Utils.get_cuda_memory_mb() + manual_controller.start_offload_layer(0) + manual_controller.release_activation_forward_gpu_memory(0) + manual_controller.start_offload_layer(1) + manual_controller.release_activation_forward_gpu_memory(1) + memory_after_offload = Utils.get_cuda_memory_mb() + assert memory_after_offload + EPSILON < memory_before_offload + + manual_controller.start_reload_layer(0) + manual_controller.start_reload_layer(1) + + memory_after_reload = Utils.get_cuda_memory_mb() + assert memory_after_reload == pytest.approx(memory_before_offload, 0.1) + + out_1.sum().backward() + out_2.sum().backward() + + @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("use_cuda_graphs", [True, False]) + @pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False]) + @pytest.mark.parametrize("backend", ["FlashAttention", "FusedAttention", "UnfusedAttention"]) + def test_numerics( + self, + recipe, + layer_type, + use_cuda_graphs, + backend, + retain_pinned_cpu_buffers, + ): + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() + ): + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + recipe_ctx = Utils.create_recipe_ctx(recipe) -def _measure_cached_memory( - modules: Iterable[torch.nn.Module], - quantization_recipe: Optional[recipe.Recipe], - cpu_offload: bool, -) -> float: - """Measure the growth in allocated GPU memory in MiB after a model forward pass. + if use_cuda_graphs and not retain_pinned_cpu_buffers: + pytest.skip( + "Cuda graphs are not yet supported with cpu offloading when" + " retain_pinned_cpu_buffers is False." + ) - Memory measurement excludes the input and output tensors. + if backend == "FusedAttention" and use_cuda_graphs: + pytest.skip( + "Fused attention + cuda graphs is temporarily broken, not because of cpu offloading" + ) - """ + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" - # Reset memory - gc.collect() - torch.cuda.empty_cache() + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + elif backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + elif backend == "UnfusedAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" - # Context and sync function for CPU offloading - if cpu_offload: - offload_context, sync_function = te.get_cpu_offload_context( + offload_ctx, sync_function = get_cpu_offload_context( enabled=True, - num_layers=len(modules), - model_layers=len(modules) + 1, + num_layers=1, + model_layers=2, offload_activations=True, offload_weights=False, + retain_pinned_cpu_buffers=retain_pinned_cpu_buffers, ) - else: - offload_context = contextlib.nullcontext() - sync_function = lambda x: x - - # Forward pass, with dummy step to trigger offload for last module - inp = _make_input() - tensor = inp - memory_before_forward = torch.cuda.memory_allocated() / (1024**2) - for module in modules: - with te.autocast( - enabled=quantization_recipe is not None, recipe=quantization_recipe - ), offload_context: - tensor = module(tensor) - tensor = sync_function(tensor) - with offload_context: - tensor = tensor.clone() - tensor = sync_function(tensor) - memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) - - # Backward pass - tensor.sum().backward() - torch.cuda.synchronize() - - # Memory usage in MiB - return memory_after_forward - memory_before_forward - - -@pytest.mark.parametrize("quantization_recipe", quantization_recipes) -@pytest.mark.parametrize("model_name", model_types.keys()) -def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: - """Check that CPU offloading runs and has expected memory usage.""" - - # Construct model - modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] - if model_name in ["multihead_attention", "transformer_layer"]: - available_backends, *_ = get_available_attention_backends( - model_config["small"], - qkv_dtype=torch.bfloat16, - qkv_layout="sbhd_sbhd_sbhd", + + class Callable(torch.nn.Module): + def __init__(self, offload_ctx=None, sync_function=None): + super().__init__() + self.layers = torch.nn.ModuleList( + [Utils.create_layer(layer_type) for _ in range(2)] + ) + self.offload_ctx = offload_ctx + self.sync_function = sync_function + + def forward(self, x): + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"] + for layer in self.layers: + with self.offload_ctx, recipe_ctx(): + if is_ops_layer: + x = layer(x, **m_splits) + else: + x = layer(x, is_first_microbatch=False, **m_splits) + if self.sync_function is not None: + x = self.sync_function(x) + return x + + callable_offload = Callable(offload_ctx=offload_ctx, sync_function=sync_function) + callable_no_offload = Callable(offload_ctx=contextlib.nullcontext(), sync_function=None) + + # copy parameters + for param_offload, param_no_offload in zip( + callable_offload.parameters(), callable_no_offload.parameters() + ): + param_offload.data.copy_(param_no_offload.data) + + x = Utils.create_tensor(None) + + if use_cuda_graphs: + callable_offload = te.make_graphed_callables( + callable_offload, + (x,), + enabled=recipe is not None, + recipe=(Utils.create_recipe_ctx(recipe) if recipe is not None else None), + ) + + # warm up (for example to compute sf for delayed scaling) + for _ in range(4): + out = callable_offload(x) + out.sum().backward() + out = callable_no_offload(x) + out.sum().backward() + + callable_offload.zero_grad(set_to_none=True) + out_offload = callable_offload(x) + out_offload.sum().backward() + + # save out and gradients + offload_outs = [out_offload] + for param in callable_offload.parameters(): + offload_outs.append(param.detach().clone()) + + torch.cuda.reset_peak_memory_stats() + out_no_offload = callable_no_offload(x) + out_no_offload.sum().backward() + + # collect gradients + no_offload_outs = [out_no_offload] + for param in callable_no_offload.parameters(): + no_offload_outs.append(param.detach().clone()) + + # check if tensors are the same + for i in range(len(offload_outs)): + assert torch.allclose(offload_outs[i], no_offload_outs[i]), f"Error in tensor {i}." + + torch.cuda.synchronize() + + def test_example_from_doc(self): + offload_stream = torch.cuda.Stream() + num_layers = 10 + layers = [Utils.create_layer("transformer_layer") for _ in range(num_layers)] + inp = [Utils.create_tensor(None) for _ in range(num_layers)] + out = [None] * num_layers + cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, + model_layers=num_layers, + manual_synchronization=True, + offload_stream=offload_stream, ) - _, fused_attn_supported, _ = available_backends - if not fused_attn_supported: - pytest.skip("Fused attention backend not available.") - os.environ["NVTE_FLASH_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True - - # Warmup - _warmup_model(modules_list, quantization_recipe) - - # Measure cached memory after forward pass - memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) - memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) - - # Check for expected memory usage - assert memory_with_offload < memory_without_offload - memory_from_cached_weights = _estimate_cached_weight_size( - model_name, - modules_list, - quantization_recipe, - ) - assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON + + for i in range(num_layers): + with cpu_offload_context: + out[i] = layers[i].forward(inp[i]) + out[i] = sync_function(out[i]) + manual_controller.start_offload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + manual_controller.release_activation_forward_gpu_memory(i) + + for i in range(num_layers - 1, -1, -1): + # these calls are intended to be done in the backward pass + manual_controller.start_reload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + out[i].sum().backward() diff --git a/tests/pytorch/test_cpu_offloading_v1.py b/tests/pytorch/test_cpu_offloading_v1.py new file mode 100644 index 000000000..8a8e03630 --- /dev/null +++ b/tests/pytorch/test_cpu_offloading_v1.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import contextlib +import gc +import os +from typing import Iterable, Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported +from utils import ModelConfig, get_available_attention_backends + +# Check supported quantization schemes +fp8_available = te.is_fp8_available() +mxfp8_available = te.is_mxfp8_available() + +quantization_recipes: Optional[recipe.Recipe] = [None] +if fp8_available: + quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) + +model_config = { + "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), +} +SIZE = model_config["small"].hidden_size +NUM_HEADS = model_config["small"].num_heads +NUM_LAYERS = model_config["small"].num_layers +EPSILON = model_config["small"].eps + +# Flash attention saves some internal tensor for the backward pass +# that cannot be offloaded to CPU. +assert os.getenv("NVTE_FLASH_ATTN") == "0" + +# CPU offload v1 code path is enabled +assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" + +# Offloading is supported for attention only for fused and flash attention backends, +# so the use of bfloat16 is required. +# +# For the TransformerLayer, activation offloading with dropout is not supported, +# so we set hidden_dropout to 0.0. +model_types = { + "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), + "multihead_attention": lambda: te.MultiheadAttention( + SIZE, NUM_HEADS, params_dtype=torch.bfloat16 + ), + "transformer_layer": lambda: te.TransformerLayer( + SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 + ), + "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + "layernorm_mlp_ops": lambda: te.ops.Sequential( + te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + ), +} + + +def _make_input() -> torch.Tensor: + """Generate random input tensor.""" + return torch.randn( + (128, SIZE, SIZE), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + + +def _warmup_model( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> None: + """Perform forward and backward pass""" + tensor = _make_input() + for module in modules: + with te.autocast( + enabled=quantization_recipe is not None, + recipe=quantization_recipe, + ): + tensor = module(tensor) + tensor.sum().backward() + + +def _estimate_cached_weight_size( + model_name: str, + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> float: + """Calculate the memory (in MiB) needed for weight caching.""" + + # The weight params are cached directly for unquantized compute + if quantization_recipe is None: + return 0 + + # Count number of weight param elements + param_elements = 0 + for module in modules: + for param in module.parameters(): + if param.dim() == 2: + param_elements += param.numel() + + # FP8 tensor-scaling caches one byte per element + if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): + if not is_non_tn_fp8_gemm_supported() and model_name not in ( + "linear_op", + "layernorm_mlp_ops", + ): + # Modules do not deallocate FP8 transpose for weights + return 2 * param_elements / 1024**2 + return param_elements / 1024**2 + + # MXFP8 caches one data byte per element and one scale byte per 32 + # elements + if quantization_recipe.mxfp8(): + if model_name not in ("linear_op", "layernorm_mlp_ops"): + # Modules do not deallocate column-wise MXFP8 data for weights + return 2 * param_elements * (1 + 1 / 32) / 1024**2 + return param_elements * (1 + 1 / 32) / 1024**2 + + raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + + +def _measure_cached_memory( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], + cpu_offload: bool, +) -> float: + """Measure the growth in allocated GPU memory in MiB after a model forward pass. + + Memory measurement excludes the input and output tensors. + + """ + + # Reset memory + gc.collect() + torch.cuda.empty_cache() + + # Context and sync function for CPU offloading + if cpu_offload: + offload_context, sync_function = te.get_cpu_offload_context( + enabled=True, + num_layers=len(modules), + model_layers=len(modules) + 1, + offload_activations=True, + offload_weights=False, + ) + else: + offload_context = contextlib.nullcontext() + sync_function = lambda x: x + + # Forward pass, with dummy step to trigger offload for last module + inp = _make_input() + tensor = inp + memory_before_forward = torch.cuda.memory_allocated() / (1024**2) + for module in modules: + with te.autocast( + enabled=quantization_recipe is not None, recipe=quantization_recipe + ), offload_context: + tensor = module(tensor) + tensor = sync_function(tensor) + with offload_context: + tensor = tensor.clone() + tensor = sync_function(tensor) + memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) + + # Backward pass + tensor.sum().backward() + torch.cuda.synchronize() + + # Memory usage in MiB + return memory_after_forward - memory_before_forward + + +@pytest.mark.parametrize("quantization_recipe", quantization_recipes) +@pytest.mark.parametrize("model_name", model_types.keys()) +def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: + """Check that CPU offloading runs and has expected memory usage.""" + + # Construct model + modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] + if model_name in ["multihead_attention", "transformer_layer"]: + available_backends, *_ = get_available_attention_backends( + model_config["small"], + qkv_dtype=torch.bfloat16, + qkv_layout="sbhd_sbhd_sbhd", + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("Fused attention backend not available.") + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + # Warmup + _warmup_model(modules_list, quantization_recipe) + + # Measure cached memory after forward pass + memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) + memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) + + # Check for expected memory usage + assert memory_with_offload < memory_without_offload + memory_from_cached_weights = _estimate_cached_weight_size( + model_name, + modules_list, + quantization_recipe, + ) + assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 147a85fc2..543055061 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -50,6 +50,13 @@ ) from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_activation_offload, + NVTE_CPU_OFFLOAD_V1, +) +from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded # Import attention utils import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils @@ -737,6 +744,9 @@ def forward( x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] + if is_cpu_offload_enabled(): + start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True) + # get batch_size, max_seqlen and cu_seqlens batch_size, context_len = None, None if inference_params is None: @@ -877,12 +887,7 @@ def forward( fp8_output=fp8_output, ) else: - from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadEnabled, - mark_activation_offload, - ) - - if CPUOffloadEnabled: + if is_cpu_offload_enabled(): mark_activation_offload( query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv ) @@ -1116,6 +1121,9 @@ def forward( nvtx_label = "transformer_engine.FusedAttnFunc.forward" nvtx_range_push(f"{nvtx_label}") + if is_cpu_offload_enabled(): + start_offload(q, k, v, offload_base_tensor=True) + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; # may be different from fp8_meta["recipe"] fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() @@ -1293,12 +1301,7 @@ def forward( # used when some tensors are base tensors and loose the "dtype" attribute ctx.nominal_dtype = out_nominal_dtype - from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadEnabled, - mark_activation_offload, - ) - - if CPUOffloadEnabled: + if is_cpu_offload_enabled() and NVTE_CPU_OFFLOAD_V1: if ctx.fp8: tensor_list = fp8_tensors else: @@ -1309,6 +1312,7 @@ def forward( ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *qkvo_tensors, @@ -1339,27 +1343,26 @@ def forward( ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill - from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadedLayer, - ) - - # If interleaved tensor is offloaded, reloaded tensor will be - # non-interleaved, so we need to modify the QKV layout - # for backward - if CPUOffloadedLayer and CPUOffloadEnabled: - reload_layout = "" - split_list = qkv_layout.split("_") - for split in split_list: - temp_layout = "" - rep_count = 1 - for s in split: - if s.isalpha(): - temp_layout = temp_layout + s - else: - rep_count = int(s) - for _ in range(rep_count): - reload_layout = reload_layout + temp_layout + "_" - ctx.qkv_layout = reload_layout[:-1] + if NVTE_CPU_OFFLOAD_V1: + # If interleaved tensor is offloaded, reloaded tensor will be + # non-interleaved, so we need to modify the QKV layout + # for backward + if is_current_layer_offloaded() and is_cpu_offload_enabled(): + reload_layout = "" + split_list = qkv_layout.split("_") + for split in split_list: + temp_layout = "" + rep_count = 1 + for s in split: + if s.isalpha(): + temp_layout = temp_layout + s + else: + rep_count = int(s) + for _ in range(rep_count): + reload_layout = reload_layout + temp_layout + "_" + ctx.qkv_layout = reload_layout[:-1] + else: + ctx.qkv_layout = qkv_layout else: ctx.qkv_layout = qkv_layout diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 4278820e7..4157e8d3a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1494,14 +1494,6 @@ def forward( fp8_output=fp8_output, ) - from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled - - if CPUOffloadEnabled: - warnings.warn( - "Attention activation Offloading is only implemented" - "with Flash Attention and Fused Attention!" - ) - if use_unfused_attention: allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" if checkpoint_core_attention: diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index b3bda677b..2440693df 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -33,6 +33,8 @@ from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb +from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled + # Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast(). # Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling" # and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa. @@ -971,7 +973,8 @@ def forward( # =========================== # Core attention computation # =========================== - + if is_cpu_offload_enabled(): + start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True) context_layer = self.core_attention( query_layer, key_layer, diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 6edc12620..bfdee3475 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -3,698 +3,748 @@ # See LICENSE for license information. """Functionality for CPU offloading of tensors saved for backward pass.""" -from __future__ import annotations -from contextlib import nullcontext -from typing import Any, Dict, Optional +from __future__ import annotations +import contextlib +from collections import defaultdict +from dataclasses import dataclass, field +import os +import warnings +from typing import Any, Optional import torch - +from torch.autograd.graph import saved_tensors_hooks from transformer_engine.debug.pytorch.debug_state import TEDebugState -from .quantized_tensor import QuantizedTensorStorage -from .tensor.float8_tensor import Float8Tensor - -__all__ = ["get_cpu_offload_context"] +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path +from .quantized_tensor import ( + restore_from_saved, + prepare_for_saving, +) -CPUOffloadEnabled = False -CPUOffloadedLayer = False - - -def mark_activation_offload(*tensors): - """Set the type of the offloading needed for a tensor.""" - if TEDebugState.debug_enabled: - raise RuntimeError("CPU offload is not supported in debug mode.") - for tensor in tensors: - if tensor is None: - continue - if type(tensor) in [torch.Tensor, torch.nn.Parameter]: - tensor.activation_offloading = True - else: - data_tensors = tensor.get_data_tensors() - for tensor in data_tensors: - if tensor is not None: - tensor.activation_offloading = True - # This is a hack to force clear the tensor after it is offloaded. - # It is needed, because .*TensorStorage classes are saved in the ctx, - # and they contain the reference to their data tensors. - tensor.needs_force_clear = True +__all__ = ["get_cpu_offload_context", "mark_not_offload", "start_offload"] +NVTE_CPU_OFFLOAD_V1 = os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" -def is_cpu_offload_enabled() -> bool: - """Check if CPU offloading is currently enabled.""" - return CPUOffloadEnabled +OFFLOAD_SYNCHRONIZER = None -class CpuOffloadSavedTensorHook: - """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. +def is_cpu_offload_enabled(): + """Returns True if CPU offload is enabled.""" + if NVTE_CPU_OFFLOAD_V1: + return v1_code_path.is_cpu_offload_enabled() + return OFFLOAD_SYNCHRONIZER is not None - In this context, the ``on_save_for_backward`` method will be called every time - a tensor is saved for backward (this includes intermediary results saved using - :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but - also those recorded by a PyTorch-defined operation). - The ``on_get_saved_tensors`` method will be called when the backward function - of this op attempts to retrieve the saved tensor from context (this includes - :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the - as input the return value of the ``on_save_for_backward``, and is meant to return - an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of - size, device and element values. +def mark_activation_offload(*tensors): + """Set the type of the offloading needed for a tensor.""" + if NVTE_CPU_OFFLOAD_V1: + v1_code_path.mark_activation_offload(*tensors) - Example: - >>> import torch - >>> from typing import Any - >>> - >>> class DummyHook(CpuOffloadSavedTensorHook): - ... - ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - ... logging.info("On save", tensor) - ... return (tensor,) - ... - ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - ... logging.info("On get", saved_state) - ... tensor, = saved_state - ... return tensor - ... - >>> a = torch.ones(5, requires_grad=True) - >>> b = torch.ones(5, requires_grad=True) * 2 - >>> with DummyHook(): - ... y = a * b - ... - On save tensor([1., 1., 1., 1., 1.], requires_grad=True) - On save tensor([2., 2., 2., 2., 2.], grad_fn=) - >>> y.sum().backward() - On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) - On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) +def mark_not_offload(*tensors: torch.Tensor): + """Marks tensors to prevent them from being offloaded.""" + if NVTE_CPU_OFFLOAD_V1: + return - """ + tensors, tensor_obj = prepare_for_saving(*tensors) - def __init__(self) -> None: - self.inside_context = False + for tensor in tensors: + if tensor is not None: + setattr(tensor, "_TE_do_not_offload", True) - def __enter__(self): - global CPUOffloadEnabled - CPUOffloadEnabled = True + restore_from_saved(tensor_obj, tensors) - self.inside_context = True - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) - def __exit__(self, *args: Any): - global CPUOffloadEnabled - CPUOffloadEnabled = False +def start_offload(*tensors: torch.Tensor, offload_base_tensor: bool = False): + """ + Marks point in on main stream where tensors are fully computed and ready to be offloaded. + If offload_base_tensor is True and the tensor is a view, the base tensor is offloaded + and reloaded - the stride and storage offset of the view are saved and restored after reload. + It is useful when multiple tensors are views of the same base tensor, + for example in MultiHeadAttention for interleaved q, k, v tensors. + """ + if NVTE_CPU_OFFLOAD_V1: + return - self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() + def _mark_tensor_for_offload(t): + if t is None: + return + # Attach an event to mark when the tensor is ready for reload. + t.start_reload_event = torch.cuda.Event() + t.start_reload_event.record(torch.cuda.current_stream()) + if offload_base_tensor and t._base is not None: + setattr(t, "offload_base_tensor", True) - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - """On save for backward.""" - raise NotImplementedError( - "`on_save_for_backward: Callable[[torch.Tensor], Any]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks" - ) + tensors, tensor_obj = prepare_for_saving(*tensors) - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - """On get saved tensor.""" - raise NotImplementedError( - "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks" - ) + for tensor in tensors: + _mark_tensor_for_offload(tensor) + restore_from_saved(tensor_obj, tensors) -class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): - """Context-manager that offloads/recovers tensors through an offload hander. - The hook just offloads/recovers the tensor object to the handler through `tensor_push` - and `tensor_pop` interface. How the offload-handler manages the offloading, recovering - or prefetching timing is transparent to this hook. +@dataclass +class TensorGroup: + """ + TensorGroup is a collection of tensors, events and auxiliary data. + It is used multiple times in the CPU offload code. """ - def __init__( - self, - offload_handler: OffloadHandler, - handler_extra_kwargs: Optional[Dict[str, Any]] = None, - debug: bool = False, - ) -> None: - if handler_extra_kwargs is None: - handler_extra_kwargs = {} - self.debug: bool = debug - self.offload_handler: OffloadHandler = offload_handler - self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs - super().__init__() - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) - return retrieve_identifier - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) - return tensor - + tensor_list: list[torch.Tensor] = field(default_factory=list) + events: list[torch.cuda.Event] = field(default_factory=list) + aux: Any = None -class OffloadHandler: - """A base class for CPU offload-handler.""" - def __init__(self) -> None: - pass +class TensorGroupProcessor: + """ + Suppose there is a tensor group T that needs to be offloaded. + Possibly we can switch T into (T_opt, aux), where T_opt is smaller and easier to offload, + offload T_opt, reload it and then restore T from (T_opt_reloaded, aux). - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - """Tensor push.""" - raise NotImplementedError( - "`tensor_push is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_push." - ) + This class contains static methods that perform these optimizations - for example + deduplication of tensors and restoring duplicates after reload. + """ - def tensor_pop(self, tensor_tag: Any, **kwargs): - """Tensor pop.""" - raise NotImplementedError( - "`tensor_pop is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_pop." - ) + @staticmethod + def tensor_group_process_before_offload(tensor_group: TensorGroup) -> tuple[TensorGroup, Any]: + """ + Call for a tensor group, just before offloading logic. + aux is a dictionary that contains auxiliary data, needed to restore pre-offload state. + """ + aux = {} + tensor_group = TensorGroupProcessor._switch_to_base_tensors(aux, tensor_group) + tensor_group = TensorGroupProcessor._deduplicate_tensors(aux, tensor_group) + return tensor_group, aux -class GroupCommitFunction(torch.autograd.Function): - """this is a dummy op with output identical to input. - However, it is necessary for marking a timepoint for offload handler to - accomplish all synchronizations. Implementing it as a function is necessary - because we need to actions in both forward and backward. - """ + @staticmethod + def tensor_group_process_after_reload(tensor_group: TensorGroup): + """ + Call for a tensor group, just after reload logic. + """ + assert tensor_group.aux is not None + tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group) + tensor_group = TensorGroupProcessor._switch_to_views(tensor_group) + return tensor_group @staticmethod - def forward(ctx, tensor, cpu_offload_handler): - # pylint: disable=missing-function-docstring - cpu_offload_handler.on_group_commit_forward() - ctx.cpu_offload_handler = cpu_offload_handler - # return the identical tensor - return tensor + def _switch_to_base_tensors(aux, tensor_group: TensorGroup) -> TensorGroup: + """ + Changes tensors to base tensors and saves view options in aux. + + It we save multiple tensors which in fact are views of the same base tensor, + this will offload only this one base tensor. It is used for example in + MultiHeadAttention for interleaved q, k, v tensors. + """ + + def _check_if_offload_base_tensor(tensor: torch.Tensor) -> bool: + if getattr(tensor, "offload_base_tensor", False): + return True + if tensor._base is not None: + # If tensor is a view of a tensor and has the same elements, + # but with different strides, we can safely offload the base tensor. + # If tensor is a view on some part of a bigger tensor, + # the decision to offload the base tensor is non-trivial and we do not do it by default. + return tensor._base.numel() == tensor.numel() + return False + + aux["views"] = [] + for tensor_id in range( # pylint: disable=consider-using-enumerate + len(tensor_group.tensor_list) + ): + tensor = tensor_group.tensor_list[tensor_id] + if _check_if_offload_base_tensor(tensor): + aux["views"].append((tensor.shape, tensor.stride(), tensor.storage_offset())) + tensor = tensor._base + assert ( + tensor is not None + ), "Cannot offload base tensor, if the tensor is not a view." + tensor_group.tensor_list[tensor_id] = tensor + else: + aux["views"].append(None) + return tensor_group @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_commit_backward() - return grad_output, None + def _deduplicate_tensors(aux, tensor_group: TensorGroup) -> TensorGroup: + """ + Deduplicate tensors. + """ + dedup_tensors: list[torch.Tensor] = [] + dedup_events: list[torch.cuda.Event] = [] + tensor_to_index: dict[int, int] = {} + aux["original_tensor_ids"] = [] + # If there are several duplicates of the same tensor, with different events, + # we keep only first event - every event is recorded when the tensor is ready to be offloaded, + # so it is the most optimal to use the first event. + for tensor_id, tensor in enumerate(tensor_group.tensor_list): + if id(tensor) in tensor_to_index: + aux["original_tensor_ids"].append(tensor_to_index[id(tensor)]) + else: + tensor_to_index[id(tensor)] = len(dedup_tensors) + dedup_tensors.append(tensor) + dedup_events.append(tensor_group.events[tensor_id]) + aux["original_tensor_ids"].append(tensor_to_index[id(tensor)]) -group_prefetch_offload_commit = GroupCommitFunction.apply + tensor_group.tensor_list = dedup_tensors + tensor_group.events = dedup_events + return tensor_group + @staticmethod + def _restore_tensor_duplicates(tensor_group: TensorGroup) -> TensorGroup: + """ + Restore tensor duplicates. + """ + new_tensor_list = [] + new_events_list = [] + for tensor_id in range(len(tensor_group.aux["original_tensor_ids"])): + original_tensor_id = tensor_group.aux["original_tensor_ids"][tensor_id] + new_tensor_list.append(tensor_group.tensor_list[original_tensor_id]) + new_events_list.append(tensor_group.events[original_tensor_id]) + + tensor_group.tensor_list = new_tensor_list + tensor_group.events = new_events_list + return tensor_group -class SynchronizedGroupOffloadHandler(OffloadHandler): - """Offload Handler that offloads/reloads in a synchronized way. - The device-to-host and host-to-device copying happen in the same stream - as the computation kernels, thus the copying will block computation. + @staticmethod + def _switch_to_views(tensor_group: TensorGroup) -> TensorGroup: + """ + Switch to views - reverse of _switch_to_base_tensors. + """ + for tensor_id, tensor in enumerate(tensor_group.tensor_list): + if tensor_group.aux["views"][tensor_id] is not None: + tensor_group.tensor_list[tensor_id] = tensor.as_strided( + *tensor_group.aux["views"][tensor_id] + ) + return tensor_group + + +class OffloadableLayerState: + """ + Class that manages offloading and reloading of tensors for a single layer. """ def __init__( - self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False - ) -> None: - super().__init__() - - self.num_offload_group = num_offload_group - self.tensor_need_offloading_checker = tensor_need_offloading_checker - self.debug = debug - - self.groupid_reset() - - def groupid_reset(self): - """Groupid reset.""" - # Data structures to label saved tensors and book-keep their cpu copies. - # Currently, on push, create a new cpu tensor and copies; on pop, copies - # the tensor back to gpu and deletes the cpu tensor. - # These will increment whenever `group_commit()` is invoked - self.current_group, self.tensor_count_current_group = (0, 0) - self.torch_tensor_count = 0 - self.tensor_tag_to_state = {} - - def on_group_commit_forward(self): - """On group commit forward.""" - # finishing up with updating current group and tensor count - self.current_group += 1 # increment - self.tensor_count_current_group = 0 # reset - - def on_group_commit_backward(self): - """On group commit backward.""" - self.current_group -= 1 - assert self.current_group >= 0 - - @staticmethod - def offload(src_tensor, pin_memory=True): - """Offload.""" - - cpu_backup = torch.empty( - src_tensor.size(), - dtype=src_tensor.dtype, - layout=src_tensor.layout, - device="cpu", - pin_memory=pin_memory, + self, + offload_stream: torch.cuda.Stream, + retain_pinned_cpu_buffers: bool = False, + ): + self.offload_stream = offload_stream + self.retain_pinned_cpu_buffers = retain_pinned_cpu_buffers + + # There are 3 tensor groups: tensors on gpu before offload, + # tensors on cpu after offload, tensors on gpu after reload. + self.fwd_gpu_tensor_group = TensorGroup() + self.cpu_tensor_group = TensorGroup() + self.bwd_gpu_tensor_group = TensorGroup() + + self.aux: dict[str, Any] = {} + + # State can be one of: not_offloaded, offload_started, + # offload_finished, reload_started. + self.state = "not_offloaded" + + def _validate_state(self, func_name: str, allowed_states: list[str]): + assert ( + self.state in allowed_states + ), f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}" + + def start_offload(self): + """ + Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream. + Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded. + This event is recorded in the start_offload or push_tensor call. + """ + self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"]) + self.state = "offload_started" + + self.fwd_gpu_tensor_group, aux = TensorGroupProcessor.tensor_group_process_before_offload( + self.fwd_gpu_tensor_group ) - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) - state = (src_tensor.device, cpu_backup) - return state - - @staticmethod - def reload(state, non_blocking=None, copy_buffer=None): - """Reload.""" - dev, cpu_backup = state - if non_blocking is None: - non_blocking = cpu_backup.is_pinned() - - if copy_buffer is None: - return cpu_backup.to(dev, non_blocking=non_blocking) - - assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" - - copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) + allocate_cpu_buffers = ( + not self.retain_pinned_cpu_buffers or len(self.cpu_tensor_group.tensor_list) == 0 + ) - return copy_buffer + for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list): + assert tensor.is_contiguous() - def tensor_push(self, tensor: torch.Tensor, **kwargs): - """Tensor push.""" - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - assert tensor_tag not in self.tensor_tag_to_state - if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( - tensor - ): - state = SynchronizedGroupOffloadHandler.offload(tensor) - self.tensor_tag_to_state[tensor_tag] = state - else: - # will be offloaded together after group commit - self.tensor_tag_to_state[tensor_tag] = tensor + # Wait for the moment the tensor is ready to be offloaded. + self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type] - return tensor_tag + with torch.cuda.stream(self.offload_stream): + if allocate_cpu_buffers: + # empty_like is defined also for QuantizedTensors + offloaded_tensor = torch.empty_like( + tensor, device=torch.device("cpu"), pin_memory=True + ) + self.cpu_tensor_group.tensor_list.append(offloaded_tensor) + else: + assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, ( + "CPU buffer shape does not match the offloaded tensor shape:" + f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} " + " Make sure that tensor shaped do not change between" + " iterations if retain_pinned_cpu_buffers is True." + ) + offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] + offloaded_tensor.copy_(tensor, non_blocking=True) + + # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, + # needed to restore pre-offload state after reload. + self.aux = aux + + self.finish_offload_event = torch.cuda.Event() + self.finish_offload_event.record(self.offload_stream) + + def release_activation_forward_gpu_memory(self): + """ + Release GPU memory of the activations. + Waits for offload to finish - memory needs to be kept alive when GPU->CPU copy is performed. + """ + self._validate_state( + func_name="release_activation_forward_gpu_memory", allowed_states=["offload_started"] + ) + self.state = "offload_finished" + + torch.cuda.current_stream().wait_event(self.finish_offload_event) # type: ignore[arg-type] + + # GPU memory can be released safely after the offload. + # Notice that the memory needs to be kept alive when GPU->CPU copy is performed. + self.fwd_gpu_tensor_group = TensorGroup() + del self.finish_offload_event + + def start_reload(self): + """ + Start reloading of tensors. + It allocates new tensors on GPU and puts copy from CPU tasks on offload stream. + """ + self._validate_state(func_name="start_reload", allowed_states=["offload_finished"]) + self.state = "reload_started" + + self.bwd_gpu_tensor_group = TensorGroup() + for tensor in self.cpu_tensor_group.tensor_list: + + # Notice that reloaded tensor is allocated on main stream, + # not offloaded stream. It is because PyTorch memory allocator + # cannot move tensors from pool of one stream to another without + # calling cudaFree and cudaMalloc again. + + # empty_like is defined also for QuantizedTensors. + reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda")) + self.offload_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.offload_stream): + reloaded_tensor.copy_(tensor, non_blocking=True) + + reload_tensor_event = torch.cuda.Event() + reload_tensor_event.record(self.offload_stream) + self.bwd_gpu_tensor_group.events.append(reload_tensor_event) + self.bwd_gpu_tensor_group.tensor_list.append(reloaded_tensor) + + self.bwd_gpu_tensor_group.aux = self.aux + self.bwd_gpu_tensor_group = TensorGroupProcessor.tensor_group_process_after_reload( + self.bwd_gpu_tensor_group + ) - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - assert tensor_tag in self.tensor_tag_to_state - state = self.tensor_tag_to_state.pop(tensor_tag) - if isinstance(state, tuple): - tensor = SynchronizedGroupOffloadHandler.reload(state) - else: - tensor = state + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + """ + It is called when a tensor is saved for backward pass. + + If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group. + If tensor is not offloaded, returns the tensor itself. + """ + self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) + + if self._check_if_offload(tensor): + self.fwd_gpu_tensor_group.tensor_list.append(tensor) + # The group is processed and offloaded at the end of the forward pass of current layer. + # To enable offloading of tensors faster we use self.offload_stream and record + # the events when the tensors are ready to be offloaded. + # It means that we do not need to wait to the end of current layer to start offloading. + if hasattr(tensor, "start_reload_event"): + self.fwd_gpu_tensor_group.events.append(tensor.start_reload_event) + else: + self.fwd_gpu_tensor_group.events.append(torch.cuda.Event()) + self.fwd_gpu_tensor_group.events[-1].record(torch.cuda.current_stream()) + return len(self.fwd_gpu_tensor_group.tensor_list) - 1 return tensor + def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: + """ + It is called when a tensor is used in backward pass. + Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish. + """ + self._validate_state( + func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"] + ) -class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): - """Compared to synchronize, this uses more memory because of the buffer but - achieves better performance due to the overlapping. D2h and h2d copying are - completely hidden behind computation if computation time of a layer is longer - than host-device communication time. Bulk offloading with delay and bulk reloading - with prefetch are implemented.""" + # 1. tensor not offloaded + if isinstance(tensor_or_tensor_id, torch.Tensor): + return tensor_or_tensor_id + # 2. the layer was not offloaded at all + if self.state == "not_offloaded": + return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] + + # 3. the layer was offloaded + assert self.state == "reload_started" + # wait for the tensor to be reloaded + torch.cuda.current_stream().wait_event( + self.bwd_gpu_tensor_group.events[tensor_or_tensor_id] + ) + return self.bwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] + + def release_all_memory(self): + """Release all gpu and cpu memory the state stored. Is called after the backward pass.""" + self.fwd_gpu_tensor_group = TensorGroup() + if not self.retain_pinned_cpu_buffers: + self.cpu_tensor_group = TensorGroup() + self.bwd_gpu_tensor_group = TensorGroup() + self.state = "not_offloaded" + + def _check_if_offload(self, t: torch.Tensor) -> bool: + """ + Check if tensor needs to be offloaded. + """ + if ( + not isinstance(t, torch.nn.Parameter) + and not getattr(t, "_TE_do_not_offload", False) + and not isinstance(t, torch._subclasses.FakeTensor) + and t.device.type == "cuda" + ): + if not t.is_contiguous() and not getattr(t, "offload_base_tensor", False): + warnings.warn( + "Tried to offload non-contiguous tensor, which is not supported. Offload of" + " this tensor will be skipped." + ) + return False + + return True + return False + + def get_offloaded_total_size_mb(self) -> float: + """ + Get total size of offloaded tensors in MB, used only for testing. + """ + + def get_tensor_size_mb(tensor): + if tensor is None: + return 0 + if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): + return sum(get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + return tensor.numel() * tensor.element_size() / (1024**2) + + total_size = 0 + for tensor in self.cpu_tensor_group.tensor_list: + total_size += get_tensor_size_mb(tensor) + return total_size + + +class OffloadSynchronizer: + """ + Base class responsible for synchronizing offloading and reloading of tensors for multiple layers. + In base class we only track layer number and + create OffloadableLayerState instances for all layers, but do not start offloading or reloading. + """ def __init__( self, - num_offload_group, # must be <= actual number of groups (number of commits) - num_model_group, - tensor_need_offloading_checker=(lambda t: True), - double_buffering=False, - debug=False, - ) -> None: - super().__init__( - num_offload_group=num_offload_group, - tensor_need_offloading_checker=tensor_need_offloading_checker, - debug=debug, - ) - # Number of layers in the model - self.num_layers = num_model_group - # Data Structure to maintain reference to activation tensors - self.tensor_tag_to_buf = {} - # Data structure to hold the FP8/MXFP8 tensor objects - self.fp8_tensor_object_map = {} - self.float8_transpose_cache_valid = {} - self.dereferencing_list = [] - # Tracking the number of layers offloaded - self.offloaded_group_count = 0 - # Core data structure that decides the window for offloading - self.layer_window_map = {} - - # Data structures fo double buffered reloading - self.double_buffering = double_buffering - self.reload_double_buffer = [[], []] - self.double_buffer_created = False - - # Logic to make offloading load balance across computation - # for optimal CPU/GPU interconnect usage - constant = 0 - for i in range(self.num_offload_group): - self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 - if i < (self.num_layers % self.num_offload_group): - self.layer_window_map[i] += i + 1 - constant = i + 1 - else: - self.layer_window_map[i] += constant - - # allocate streams and events for synchronization - self.d2h_stream = torch.cuda.Stream() - self.h2d_stream = torch.cuda.Stream() - - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - global CPUOffloadedLayer - - torch_stray_tensor = isinstance( - tensor, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), + num_layers: int, + retain_pinned_cpu_buffers: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, + ): + self.num_layers = num_layers + self.offload_stream = offload_stream if offload_stream is not None else torch.cuda.Stream() + + self.layer_states = { + i: OffloadableLayerState(self.offload_stream, retain_pinned_cpu_buffers) + for i in range(num_layers) + } + + self.num_of_fwds = None + self.previous_bwd_layer_id = None + self.current_layer_id = None + + def fwd_step(self) -> int: + """ + Invoked before each layer forward. + """ + if self.num_of_fwds in [None, self.num_layers - 1]: + # reset the offload synchronizer + self.num_of_fwds = 0 + else: + self.num_of_fwds += 1 + self.current_layer_id = self.num_of_fwds + return self.current_layer_id + + def bwd_step(self, layer_num: int): + """ + Invoked before each layer backward. + """ + if self.previous_bwd_layer_id is not None: + self.layer_states[self.previous_bwd_layer_id].release_all_memory() + self.previous_bwd_layer_id = layer_num + self.current_layer_id = layer_num + + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + """Default push tensor method""" + return self.layer_states[self.num_of_fwds].push_tensor(tensor) + + def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: + """Default pop tensor method""" + return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id) + + def finish_part_of_bwd(self): + """ + We need to release memory of backward - this call does that. + It needs to be invoked after every backward pass - there may be + more than one in pipeline parallelism. + + It is needed, because call bwd_step is invoked before each layer backward, + but we need to release memory after the backward pass is finished. + """ + if self.previous_bwd_layer_id is not None: + self.layer_states[self.previous_bwd_layer_id].release_all_memory() + self.previous_bwd_layer_id = None + + def get_offloaded_total_size_mb(self) -> float: + """ + Get total size of offloaded tensors in MB, used only for testing. + """ + return sum( + self.layer_states[layer_id].get_offloaded_total_size_mb() + for layer_id in self.layer_states ) - is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage) - - if not torch_stray_tensor: - - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - - assert tensor_tag not in self.tensor_tag_to_state - if is_quantized_tensor: - tensor_list, _ = tensor.prepare_for_saving() - - self.tensor_tag_to_state[tensor_tag] = [] - self.tensor_tag_to_buf[tensor_tag] = [] - - # Added support for de-duplicating FP8 param tensors - for _, value in self.fp8_tensor_object_map.items(): - if tensor is value: - self.dereferencing_list.append(tensor_tag) - break +class DefaultOffloadSynchronizer(OffloadSynchronizer): + """ + Default implementation of OffloadSynchronizer, + intended to be used in standard training workloads - with multiple forwards + and multiple backwards. + """ - self.fp8_tensor_object_map[tensor_tag] = tensor - if isinstance(tensor, Float8Tensor): - self.float8_transpose_cache_valid[tensor_tag] = getattr( - tensor, "_transpose_invalid" - ) + def __init__( + self, + num_layers: int, + num_offloaded_layers: int | None = None, + retain_pinned_cpu_buffers: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, + ): + super().__init__(num_layers, retain_pinned_cpu_buffers, offload_stream) + + # map of layers to bool meaning if layer needs to be offloaded + self.offload_layer_map: dict[int, bool] = {} + + # num_layer: int -> list of layers that need to finish offload by this moment + self.finish_offload_map: defaultdict[int, list[int]] = defaultdict(list) + # num_layer: int -> list of layers that need to start reload in this moment + self.start_reload_map: defaultdict[int, list[int]] = defaultdict(list) + + self._init_offload_synchronization_dicts(num_offloaded_layers) + + def _init_offload_synchronization_dicts(self, num_offloaded_layers: int): + """ + If synchronization dictionary is not provided, the number of offloaded layers is used to initialize + offload_layer_map, finish_offload_map and start_reload_map. + + The aim is to minimize memory usage by the end of the forward pass. + + The optimal strategy for that is to offload layers 0, ..., num_offloaded_layers - 1. + For layer i offload needs to finish before num_layers - num_offloaded_layers + i. + For layer i reload needs to start after num_layers - num_offloaded_layers + i. + + This ensures that - if all layers have memory footprint of T - then peak memory usage of saving activations is + (num_layers - num_offloaded_layers) * T. + """ + for layer_id in range(self.num_layers): + if layer_id < num_offloaded_layers: + self.offload_layer_map[layer_id] = True + self.finish_offload_map[self.num_layers - num_offloaded_layers + layer_id].append( + layer_id + ) + self.start_reload_map[self.num_layers - 1 - num_offloaded_layers + layer_id].append( + layer_id + ) else: - tensor_list = [tensor] - - for t in tensor_list: - if is_quantized_tensor: - self.tensor_tag_to_state[tensor_tag].append(t) - else: - self.tensor_tag_to_state[tensor_tag] = t - - if ( - self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(t) - ): - if is_quantized_tensor: - self.tensor_tag_to_buf[tensor_tag].append(t) - # Need to clear the internal data reference for the quantized tensors - tensor.clear() - else: - self.tensor_tag_to_buf[tensor_tag] = t - - # Needed to differentiate non offloaded layer's attention - # QKV layout of attention of non-offloaded layer needs - # to be modified while reloading - CPUOffloadedLayer = True - else: - tensor_tag = (-1, self.torch_tensor_count) - self.torch_tensor_count += 1 - self.tensor_tag_to_state[tensor_tag] = tensor + self.offload_layer_map[layer_id] = False - return tensor_tag + def fwd_step(self) -> int: + """ + Invoked before each layer forward. + """ + super().fwd_step() + if self.offload_layer_map.get(self.current_layer_id - 1, False): + self.layer_states[self.current_layer_id - 1].start_offload() - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - global CPUOffloadedLayer + for layer in self.finish_offload_map[self.current_layer_id]: + self.layer_states[layer].release_activation_forward_gpu_memory() + return self.current_layer_id - assert tensor_tag in self.tensor_tag_to_state - tensor = self.tensor_tag_to_state.pop(tensor_tag) + def bwd_step(self, layer_num: int): + """ + Invoked before each layer backward. + """ + super().bwd_step(layer_num) - # Handling the quantized tensor case specially here - if isinstance(tensor, list): - # If it's a duplicated tensor, we don't need to locally - # write back a tensor as it would already be written - if tensor_tag in self.dereferencing_list: - self.dereferencing_list.remove(tensor_tag) - else: - self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) - tensor = self.fp8_tensor_object_map.pop(tensor_tag) + for layer in self.start_reload_map[layer_num]: + self.layer_states[layer].start_reload() - if self.double_buffering: - tensor._do_not_clear = True - self.tensor_tag_to_buf.pop(tensor_tag, None) - # the tensor should have been copied back in on_group_commit_backward() - # which invokes bulk_reload_group. - assert not isinstance(tensor, tuple) - return tensor +class ManualOffloadSynchronizer(OffloadSynchronizer): + """ + Manual implementation of OffloadSynchronizer, + all synchronization is done manually by the user by using + one of the following methods: + - start_offload_layer + - release_activation_forward_gpu_memory + - start_reload_layer + + This implementation is intended to be used in more complex trainigs workflows. + It is useful for example in pipeline parallelism. + """ - def bulk_offload_group(self, group_to_offload): - """Bulk offload group.""" - with torch.cuda.stream(self.d2h_stream): - for tensor_tag, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_tag - if group_id == group_to_offload: - assert not isinstance(state, tuple) - - is_quantized_tensor = isinstance(state, list) - - if is_quantized_tensor: - tensor_list = state - self.tensor_tag_to_state[tensor_tag] = [] - else: - tensor_list = [state] - - for tensor_on_device in tensor_list: - # `tensor_offloaded` is a hacky way of dealing with columnwise-only - # quantized tensors for CPU offloading. The complication is due to - # the `rowwise_data` being `None`. The offloading checker incorrectly - # returns `False` and the entire `state` ([None, columnwise_tensor]) - # is added to the tensor tag state dict. A better design would change - # how quantized tensors are kept track of in the offload handler. - # Currently at every stage it is ensured that a quantized tensor is a - # list whereas a non-quantized tensor is standalone object, which is - # not good! TODO(@sanandaraj5597) - tensor_offloaded = False - # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker(tensor_on_device): - tensor_offloaded = True - state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) - if is_quantized_tensor: - if tensor_offloaded: - self.tensor_tag_to_state[tensor_tag].append(state) - else: - self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) - else: - self.tensor_tag_to_state[tensor_tag] = state - - def synchronize_on_group_commit_forward(self, current_group): - """Synchronize on group commit forward.""" - global CPUOffloadedLayer - - # For the first group, kickstart the offload after we have - # the first compute completion - if current_group == 0: - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - - if not self.double_buffer_created: - # Creating the first copy of double buffer for tensors that are offloaded - for tensor_tag, buf in self.tensor_tag_to_buf.items(): - if isinstance(buf, list): - for b in buf: - self.reload_double_buffer[0].append( - torch.empty_like(b) if self.double_buffering else None - ) - else: - self.reload_double_buffer[0].append( - torch.empty_like(buf) if self.double_buffering else None - ) - - self.bulk_offload_group(current_group) - - # Window map data structure helps us synchronize based on number - # of layers offloaded - if self.layer_window_map[self.offloaded_group_count] == current_group: - - # Stream synchronization both ways - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.d2h_stream) - - # Time to free the activation memory after usage - for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): - if tensor_tag[0] == self.offloaded_group_count: - if hasattr(tensor_buf, "needs_force_clear"): - # Need to clear activation tensor - sometimes references persist in the code. - # This is the case for example with the Float8TensorStorage class, - # which is saved directly inside the ctx while its internal tensors are - # saved inside save_for_backward. - tensor_buf.data = torch.Tensor() - # Release the pointer to the tensor - self.tensor_tag_to_buf[tensor_tag] = None - - # Time to offload the next group - if self.offloaded_group_count < (self.num_offload_group - 1): - self.bulk_offload_group(self.offloaded_group_count + 1) - - # Increment the offload group count to keep track - self.offloaded_group_count += 1 - - if current_group == (self.num_offload_group - 1): - CPUOffloadedLayer = False - - if not self.double_buffer_created: - # Creating second copy of double buffer for tensors that are offloaded - if current_group == (self.num_layers - 1): - for buf in self.reload_double_buffer[0]: - self.reload_double_buffer[1].append( - torch.empty_like(buf) if self.double_buffering else None - ) - self.double_buffer_created = True - - def on_group_commit_forward(self): - """This function will cause host device synchronization""" - # handle synchronization events - self.synchronize_on_group_commit_forward(self.current_group) - - super().on_group_commit_forward() - - def bulk_reload_group(self, group_to_reload): - """Bulk reload group.""" - assert group_to_reload < self.num_offload_group - - buffer_idx = 0 - double_buffer_idx = group_to_reload % 2 - - main_stream = torch.cuda.current_stream() - - with torch.cuda.stream(self.h2d_stream): - # move back tensors - for tensor_label, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_to_reload: - - if isinstance(state, tuple): - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] - else: - with torch.cuda.stream(main_stream): - reload_buffer = torch.empty_like( - state[1], device=torch.cuda.current_device() - ) - - recovered_tensor = SynchronizedGroupOffloadHandler.reload( - state, True, reload_buffer - ) - buffer_idx = buffer_idx + 1 - self.tensor_tag_to_state[tensor_label] = recovered_tensor - elif isinstance(state, list): - tensor_list = [] - for state_tuple in state: - - if isinstance(state_tuple, tuple): - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][ - buffer_idx - ] - else: - with torch.cuda.stream(main_stream): - reload_buffer = torch.empty_like( - state_tuple[1], device=torch.cuda.current_device() - ) - - tensor_list.append( - SynchronizedGroupOffloadHandler.reload( - state_tuple, - True, - reload_buffer, - ) - ) - buffer_idx = buffer_idx + 1 - else: - tensor_list.append(state_tuple) - - # No need to write back the duplicated tensor againn - # to the same location, this check ensures that - if tensor_label in self.dereferencing_list: - self.dereferencing_list.remove(tensor_label) - else: - _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved( - tensor_list - ) - - if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): - self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( - self.float8_transpose_cache_valid.pop(tensor_label) - ) - - self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( - tensor_label - ) - - def on_group_commit_backward(self): - # first decrement the current group. - # after last commit in forward, the group will +1; in backward it -1. - # Finally it should be decremented to 0. - self.current_group -= 1 - assert self.current_group >= 0 - - # Layer window data structure helps us to reload at right times - if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: - - # Stream synchronization both ways - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.h2d_stream) - - # Time to reload the next group - self.bulk_reload_group(self.offloaded_group_count - 1) - - # Decrease the offloading group counter - self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 - - # Last group computation needs to wait till all the reloads complete - if self.current_group == 0: - torch.cuda.current_stream().wait_stream(self.h2d_stream) - self.offloaded_group_count = 0 + def start_offload_layer(self, layer_id: int): + """ + Start offloading of the layer. + Each tensor GPU->CPU copy is done asynchronously on the offload stream. + Start of each copy is started after tensor_push() is called on the current stream. + """ + self.layer_states[layer_id].start_offload() + + def release_activation_forward_gpu_memory(self, layer_id: int): + """ + Release memory of the activations of the layer. + It waits for the offload of the layer to finish. + """ + self.layer_states[layer_id].release_activation_forward_gpu_memory() + + def start_reload_layer(self, layer_id: int): + """ + Start reloading of the layer. + Each tensor reload is awaited to finish before tensor_pop() for that tensor is called on the current stream. + """ + self.layer_states[layer_id].start_reload() def get_cpu_offload_context( enabled: bool = False, - num_layers: int = 1, + num_layers: Optional[int] = 1, model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = False, - double_buffering: bool = False, + double_buffering: bool = False, # pylint: disable=unused-argument + manual_synchronization: bool = False, + retain_pinned_cpu_buffers: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, ): """ - This function returns the CPU Offload context and the synchronizer function that needs to be - used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + CPU Offloading feature for seqeuences of layers. Can be used for arbitrary layers, not necessarily + for these provided by the TE. Usage: .. code-block:: python - cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + cpu_offload_context, sync_function = get_cpu_offload_context(...) - with cpu_offload_context: - te_layer.forward(inp_tensor) - cpu_offload_synchronizer() + for _ in range(num_layers): + with cpu_offload_context: + x = layers[i].forward(x) + x = sync_function(x) Parameters ---------- enabled: bool, default = `False` When set to True, CPU Offloading functionality is enabled. num_layers: int, default = 1 - Determines the number of transformer layers - you want to offload activations/weights for. + Determines the number of layers + you want to offload activations/weights for. model_layers: int, default = 1 - Number of layers in the model that will be used under this context. + Number of layers in the model that will be used under this context. offload_activations: bool, default = `True` - When set to `True`, offloads the activations for the TE layer. + Deprecated. offload_weights: bool, default = `True` - When set to `True`, offloads the weights for the TE layer. + Deprecated. double_buffering: bool, default = `False` - When set to `True`, uses double buffering for offloading. + Deprecated. + retain_pinned_cpu_buffers: bool, default = `False` + If True, the pinned CPU buffers are retained after offloading + and reused for the next iteration. It is useful for cuda graphs capture. + manual_synchronization: bool, default = `False` + If True, the synchronization is done manually by the user. + Additional argument manual_controller is returned. See more in manual control section. + offload_stream: torch.cuda.Stream, default = `None` + If provided, the offload stream is used for offloading and reloading. + Otherwise, a new stream is allocated internally. It can be other than None + only if manual_synchronization is True. + + Manual synchronization + ---------- + By default, layers are offloaded/reloaded asynchronously + with respect to the current forward/backward stream with predefined synchronization, + to ensure that activation memory usage is equal to + `(num_layers - num_offloaded_layers) * T`, where `T` is the memory footprint of a layer. + + For more control over the offloading and reloading process, you can set `manual_synchronization=True`. + In this case, an additional argument, `manual_controller`, is returned. + + The `manual_controller` provides the following methods: + - `start_offload_layer(layer_id: int)` + - `release_activation_forward_gpu_memory(layer_id: int)` + - `start_reload_layer(layer_id: int)` + + If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded. + If `start_offload_layer()` is called for a layer, offload copies for that layer begin asynchronously on the offload stream. + + Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored. + To release this memory, you need to call `release_activation_forward_gpu_memory(layer_id)`. + This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded. + + The `start_reload_layer()` method is used to start reloading a layer. + Each tensor reload is awaited to finish before `tensor_pop()` for that tensor is called on the current stream. + + You can provide an `offload_stream` to be used for offload and reload operations. + This allows for more detailed synchronization, such as delaying the start of offloading. + + Example: + .. code-block:: python + offload_stream = torch.cuda.Stream() + cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream) + + for i in range(num_layers): + with cpu_offload_context: + out[i] = layers[i].forward(inp[i]) + out[i] = sync_function(out[i]) + manual_controller.start_offload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + manual_controller.release_activation_forward_gpu_memory(i) + + for i in range(num_layers - 1, -1, -1): + manual_controller.start_reload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + out[i].sum().backward() + + V1 code path + ---------- + If you want to use the v1 code path for offloading, + please set the environment variable NVTE_CPU_OFFLOAD_V1 to 1. """ + if NVTE_CPU_OFFLOAD_V1: + return v1_code_path.get_cpu_offload_context( + enabled=enabled, + num_layers=num_layers, + model_layers=model_layers, + offload_activations=offload_activations, + offload_weights=offload_weights, + double_buffering=double_buffering, + ) if not offload_weights and not offload_activations: raise ValueError( @@ -703,8 +753,6 @@ def get_cpu_offload_context( ) if offload_weights: - import warnings - warnings.warn( "Offloading weights is deprecated. Using offload_weights=True does not have any" " effect.", @@ -713,26 +761,100 @@ def get_cpu_offload_context( # Weights offloading is deprecated but we maintain backward compatibility by doing nothing. if not offload_activations: - return nullcontext(), lambda x: x + return contextlib.nullcontext(), lambda x: x - def tensor_need_offloading_checker_activations(tensor): - return hasattr(tensor, "activation_offloading") - - tensor_need_offloading_checker = tensor_need_offloading_checker_activations + if TEDebugState.debug_enabled: + raise RuntimeError("CPU offload is not supported in debug mode.") - cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( - num_offload_group=num_layers, - num_model_group=model_layers, - tensor_need_offloading_checker=tensor_need_offloading_checker, - double_buffering=double_buffering, - ) + if not manual_synchronization: + assert ( + num_layers <= model_layers - 1 + ), "Cannot offload all layers without manual synchronization - last layer is not offloaded." + if num_layers == model_layers - 1: + warnings.warn( + "Offloading num_layers == model_layers - 1 is not recommended, it prevents" + " overlapping of computation and offload/reload." + ) + + assert ( + offload_stream is None or manual_synchronization + ), "offload_stream can be provided only if manual_synchronization is True" + + if manual_synchronization: + offload_synchronizer = ManualOffloadSynchronizer( + model_layers, retain_pinned_cpu_buffers, offload_stream + ) + else: + offload_synchronizer = DefaultOffloadSynchronizer( + model_layers, + num_layers, + retain_pinned_cpu_buffers, + offload_stream, + ) - def group_prefetch_offload_commit_async(tensor): - return group_prefetch_offload_commit(tensor, cpu_offload_handler) + class _CpuOffloadContext(contextlib.ContextDecorator): + def __init__(self): + self.current_layer = None + self.previous_offload_synchronizer = None + self.offload_synchronizer = offload_synchronizer + + self.inside_context = False + + def __enter__(self): + assert ( + self.inside_context is False + ), "Offloading context was entered without synchronization function being called." + self.inside_context = True + self._hooks_ctx = saved_tensors_hooks( + offload_synchronizer.push_tensor, offload_synchronizer.pop_tensor + ) + self._hooks_ctx.__enter__() + global OFFLOAD_SYNCHRONIZER + self.previous_offload_synchronizer = OFFLOAD_SYNCHRONIZER + OFFLOAD_SYNCHRONIZER = offload_synchronizer + self.current_layer = offload_synchronizer.fwd_step() + return self + + def __exit__(self, *args): + self._hooks_ctx.__exit__(*args) + global OFFLOAD_SYNCHRONIZER + OFFLOAD_SYNCHRONIZER = self.previous_offload_synchronizer + self.inside_context = False + + def synchronization_function(self, tensor): + """ + This function is used to catch the backward pass of the model. + """ + assert tensor.requires_grad is True + assert self.current_layer is not None + cur_layer = self.current_layer + assert ( + self.inside_context is False + ), "Synchronization function was called without offloading context being entered." + + def hook(_): + # offload_synchronizer.finish_part_of_bwd needs + # to be called after every backward pass - there may be + # more than one in pipeline parallelism. + torch.autograd.variable.Variable._execution_engine.queue_callback( + offload_synchronizer.finish_part_of_bwd + ) + offload_synchronizer.bwd_step(cur_layer) + + tensor.grad_fn.register_prehook(hook) + return tensor + + cpu_offload_context = _CpuOffloadContext() if enabled: + if manual_synchronization: + return ( + cpu_offload_context, + cpu_offload_context.synchronization_function, + offload_synchronizer, + ) return ( - CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), - group_prefetch_offload_commit_async, + cpu_offload_context, + cpu_offload_context.synchronization_function, ) - return nullcontext(), group_prefetch_offload_commit_async + return contextlib.nullcontext(), lambda x: x diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py new file mode 100644 index 000000000..9f904864a --- /dev/null +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -0,0 +1,743 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch + +from transformer_engine.debug.pytorch.debug_state import TEDebugState +from .quantized_tensor import QuantizedTensorStorage +from .tensor.float8_tensor import Float8Tensor + +__all__ = ["get_cpu_offload_context"] + +CPUOffloadEnabled = False +CPUOffloadedLayer = False + + +def mark_activation_offload(*tensors): + """Set the type of the offloading needed for a tensor.""" + if TEDebugState.debug_enabled: + raise RuntimeError("CPU offload is not supported in debug mode.") + + for tensor in tensors: + if tensor is None: + continue + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + tensor.activation_offloading = True + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + tensor.activation_offloading = True + # This is a hack to force clear the tensor after it is offloaded. + # It is needed, because .*TensorStorage classes are saved in the ctx, + # and they contain the reference to their data tensors. + tensor.needs_force_clear = True + + +def is_cpu_offload_enabled() -> bool: + """Check if CPU offloading is currently enabled.""" + return CPUOffloadEnabled + + +def is_current_layer_offloaded() -> bool: + """Check if current layers is being offloaded.""" + return CPUOffloadedLayer + + +class CpuOffloadSavedTensorHook: + """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. + + In this context, the ``on_save_for_backward`` method will be called every time + a tensor is saved for backward (this includes intermediary results saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). + + The ``on_get_saved_tensors`` method will be called when the backward function + of this op attempts to retrieve the saved tensor from context (this includes + :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the + as input the return value of the ``on_save_for_backward``, and is meant to return + an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of + size, device and element values. + + Example: + + >>> import torch + >>> from typing import Any + >>> + >>> class DummyHook(CpuOffloadSavedTensorHook): + ... + ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + ... logging.info("On save", tensor) + ... return (tensor,) + ... + ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + ... logging.info("On get", saved_state) + ... tensor, = saved_state + ... return tensor + ... + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with DummyHook(): + ... y = a * b + ... + On save tensor([1., 1., 1., 1., 1.], requires_grad=True) + On save tensor([2., 2., 2., 2., 2.], grad_fn=) + >>> y.sum().backward() + On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) + On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) + + """ + + def __init__(self) -> None: + self.inside_context = False + + def __enter__(self): + global CPUOffloadEnabled + CPUOffloadEnabled = True + + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) + + def __exit__(self, *args: Any): + global CPUOffloadEnabled + CPUOffloadEnabled = False + + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """On save for backward.""" + raise NotImplementedError( + "`on_save_for_backward: Callable[[torch.Tensor], Any]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """On get saved tensor.""" + raise NotImplementedError( + "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + +class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[Dict[str, Any]] = None, + debug: bool = False, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.debug: bool = debug + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs + super().__init__() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push." + ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop." + ) + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__( + self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False + ) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + self.debug = debug + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. + # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None, copy_buffer=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + + if copy_buffer is None: + return cpu_backup.to(dev, non_blocking=non_blocking) + + assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" + + copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) + + return copy_buffer + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( + tensor + ): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + double_buffering=False, + debug=False, + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + debug=debug, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Data structure to hold the FP8/MXFP8 tensor objects + self.fp8_tensor_object_map = {} + self.float8_transpose_cache_valid = {} + self.dereferencing_list = [] + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + + # Data structures fo double buffered reloading + self.double_buffering = double_buffering + self.reload_double_buffer = [[], []] + self.double_buffer_created = False + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + global CPUOffloadedLayer + + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) + + is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage) + + if not torch_stray_tensor: + + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + + if is_quantized_tensor: + tensor_list, _ = tensor.prepare_for_saving() + + self.tensor_tag_to_state[tensor_tag] = [] + self.tensor_tag_to_buf[tensor_tag] = [] + + # Added support for de-duplicating FP8 param tensors + for _, value in self.fp8_tensor_object_map.items(): + if tensor is value: + self.dereferencing_list.append(tensor_tag) + break + + self.fp8_tensor_object_map[tensor_tag] = tensor + if isinstance(tensor, Float8Tensor): + self.float8_transpose_cache_valid[tensor_tag] = getattr( + tensor, "_transpose_invalid" + ) + else: + tensor_list = [tensor] + + for t in tensor_list: + if is_quantized_tensor: + self.tensor_tag_to_state[tensor_tag].append(t) + else: + self.tensor_tag_to_state[tensor_tag] = t + + if ( + self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(t) + ): + if is_quantized_tensor: + self.tensor_tag_to_buf[tensor_tag].append(t) + # Need to clear the internal data reference for the quantized tensors + tensor.clear() + else: + self.tensor_tag_to_buf[tensor_tag] = t + + # Needed to differentiate non offloaded layer's attention + # QKV layout of attention of non-offloaded layer needs + # to be modified while reloading + CPUOffloadedLayer = True + else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + global CPUOffloadedLayer + + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + + # Handling the quantized tensor case specially here + if isinstance(tensor, list): + # If it's a duplicated tensor, we don't need to locally + # write back a tensor as it would already be written + if tensor_tag in self.dereferencing_list: + self.dereferencing_list.remove(tensor_tag) + else: + self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) + tensor = self.fp8_tensor_object_map.pop(tensor_tag) + + if self.double_buffering: + tensor._do_not_clear = True + + self.tensor_tag_to_buf.pop(tensor_tag, None) + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + + is_quantized_tensor = isinstance(state, list) + + if is_quantized_tensor: + tensor_list = state + self.tensor_tag_to_state[tensor_tag] = [] + else: + tensor_list = [state] + + for tensor_on_device in tensor_list: + # `tensor_offloaded` is a hacky way of dealing with columnwise-only + # quantized tensors for CPU offloading. The complication is due to + # the `rowwise_data` being `None`. The offloading checker incorrectly + # returns `False` and the entire `state` ([None, columnwise_tensor]) + # is added to the tensor tag state dict. A better design would change + # how quantized tensors are kept track of in the offload handler. + # Currently at every stage it is ensured that a quantized tensor is a + # list whereas a non-quantized tensor is standalone object, which is + # not good! TODO(@sanandaraj5597) + tensor_offloaded = False + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + tensor_offloaded = True + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + if is_quantized_tensor: + if tensor_offloaded: + self.tensor_tag_to_state[tensor_tag].append(state) + else: + self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) + else: + self.tensor_tag_to_state[tensor_tag] = state + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + global CPUOffloadedLayer + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + if not self.double_buffer_created: + # Creating the first copy of double buffer for tensors that are offloaded + for tensor_tag, buf in self.tensor_tag_to_buf.items(): + if isinstance(buf, list): + for b in buf: + self.reload_double_buffer[0].append( + torch.empty_like(b) if self.double_buffering else None + ) + else: + self.reload_double_buffer[0].append( + torch.empty_like(buf) if self.double_buffering else None + ) + + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + if hasattr(tensor_buf, "needs_force_clear"): + # Need to clear activation tensor - sometimes references persist in the code. + # This is the case for example with the Float8TensorStorage class, + # which is saved directly inside the ctx while its internal tensors are + # saved inside save_for_backward. + tensor_buf.data = torch.Tensor() + # Release the pointer to the tensor + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + if current_group == (self.num_offload_group - 1): + CPUOffloadedLayer = False + + if not self.double_buffer_created: + # Creating second copy of double buffer for tensors that are offloaded + if current_group == (self.num_layers - 1): + for buf in self.reload_double_buffer[0]: + self.reload_double_buffer[1].append( + torch.empty_like(buf) if self.double_buffering else None + ) + self.double_buffer_created = True + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + buffer_idx = 0 + double_buffer_idx = group_to_reload % 2 + + main_stream = torch.cuda.current_stream() + + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload: + + if isinstance(state, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state[1], device=torch.cuda.current_device() + ) + + recovered_tensor = SynchronizedGroupOffloadHandler.reload( + state, True, reload_buffer + ) + buffer_idx = buffer_idx + 1 + self.tensor_tag_to_state[tensor_label] = recovered_tensor + elif isinstance(state, list): + tensor_list = [] + for state_tuple in state: + + if isinstance(state_tuple, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][ + buffer_idx + ] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state_tuple[1], device=torch.cuda.current_device() + ) + + tensor_list.append( + SynchronizedGroupOffloadHandler.reload( + state_tuple, + True, + reload_buffer, + ) + ) + buffer_idx = buffer_idx + 1 + else: + tensor_list.append(state_tuple) + + # No need to write back the duplicated tensor againn + # to the same location, this check ensures that + if tensor_label in self.dereferencing_list: + self.dereferencing_list.remove(tensor_label) + else: + _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved( + tensor_list + ) + + if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): + self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( + self.float8_transpose_cache_valid.pop(tensor_label) + ) + + self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( + tensor_label + ) + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. + self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_cpu_offload_context( + enabled: bool = False, + num_layers: int = 1, + model_layers: int = 1, + offload_activations: bool = True, + offload_weights: bool = False, + double_buffering: bool = False, +): + """ + This function returns the CPU Offload context and the synchronizer function that needs to be + used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + + Usage: + + .. code-block:: python + + cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + + with cpu_offload_context: + te_layer.forward(inp_tensor) + cpu_offload_synchronizer() + + Parameters + ---------- + enabled: bool, default = `False` + When set to True, CPU Offloading functionality is enabled. + num_layers: int, default = 1 + Determines the number of transformer layers + you want to offload activations/weights for. + model_layers: int, default = 1 + Number of layers in the model that will be used under this context. + offload_activations: bool, default = `True` + When set to `True`, offloads the activations for the TE layer. + offload_weights: bool, default = `True` + When set to `True`, offloads the weights for the TE layer. + double_buffering: bool, default = `False` + When set to `True`, uses double buffering for offloading. + + """ + + if not offload_weights and not offload_activations: + raise ValueError( + "CPU Offloading is enabled while it is not " + "mentioned what to offload (weights/activations)" + ) + + if offload_weights: + import warnings + + warnings.warn( + "Offloading weights is deprecated. Using offload_weights=True does not have any" + " effect.", + DeprecationWarning, + ) + + # Weights offloading is deprecated but we maintain backward compatibility by doing nothing. + if not offload_activations: + return nullcontext(), lambda x: x + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor, "activation_offloading") + + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + double_buffering=double_buffering, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + if enabled: + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + return nullcontext(), group_prefetch_offload_commit_async diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index f336a743d..1a56a06da 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -41,7 +41,7 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..cpu_offload import is_cpu_offload_enabled +from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..quantized_tensor import ( @@ -135,6 +135,9 @@ def forward( else: inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + if cpu_offloading: + start_offload(*inputmats) + # Initialize weights weights_fp8: list if fp8: @@ -196,6 +199,9 @@ def forward( for i in range(num_gemms): weight_quantizers[i].calibrate(weights[i]) + if cpu_offloading: + mark_not_offload(*weights_fp8, *weights) + if is_grad_enabled: ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 20a67cba4..4ed3ebb73 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -66,10 +66,15 @@ from ...debug.pytorch.debug_state import TEDebugState from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_not_offload, + mark_activation_offload, +) from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..export import is_in_onnx_export_mode, assert_warmed_up -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpp_extensions import ( general_gemm, @@ -158,6 +163,9 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") + if is_cpu_offload_enabled(): + start_offload(inputmat) + tp_world_size = get_distributed_world_size(tp_group) weight_requires_grad = weight.requires_grad @@ -434,8 +442,14 @@ def forward( nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") if cpu_offloading: + mark_not_offload( + weightmat, + weight, + bias, + ln_weight, + ln_bias, + ) ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: # If you are passing torch.nn.Parameter through the Torch hooks, you will # get back torch.Tensor. Torch rips off the Parameter wrapper. @@ -542,6 +556,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. ctx.tensor_objects = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a358ae7dd..c29775c92 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -69,7 +69,12 @@ from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, WeightGradStore -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ..cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_not_offload, + mark_activation_offload, +) from ..quantized_tensor import ( QuantizedTensorStorage, Quantizer, @@ -235,6 +240,8 @@ def forward( ln_weight = cast_if_needed(ln_weight, activation_dtype) if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + if is_cpu_offload_enabled(): + start_offload(inputmat) tp_world_size = get_distributed_world_size(tp_group) backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad @@ -577,6 +584,18 @@ def forward( clear_tensor_data(act_out) act_out = None + if cpu_offloading: + mark_not_offload( + ln_weight, + ln_bias, + fc1_weight_final, + fc1_weight, + fc1_bias, + fc2_weight_final, + fc2_weight, + fc2_bias, + ) + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 46b9dbd85..00b78995f 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -68,7 +68,12 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.utils import is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ..cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_not_offload, + mark_activation_offload, +) from ...debug.pytorch.debug_state import TEDebugState __all__ = ["Linear"] @@ -229,6 +234,9 @@ def forward( else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP inputmat_total = inputmat + + if is_cpu_offload_enabled(): + start_offload(inputmat) nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # ------------------------------------------------------ # Input tensor is ready for GEMM... @@ -417,6 +425,7 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight + mark_not_offload(weight, weightmat, bias) # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 18f7e2031..73b312ec2 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -372,9 +372,9 @@ def _initialize_state( """ dtype = self.name_to_dtype_map[state_name] if store_param_remainders: - data = torch.zeros_like(param, dtype=torch.int16) + data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) else: - data = torch.empty_like(param, dtype=dtype) + data = torch.empty(param.shape, dtype=dtype, device=param.device) if zero_buffer: data.zero_() diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 7d49e3964..c830b19e9 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -9,6 +9,7 @@ import abc import copy import warnings +import math import torch from torch.utils._pytree import tree_map @@ -20,6 +21,11 @@ _stride_from_shape, ) +_quantized_tensor_cpu_supported_ops = ( + torch.ops.aten.empty_like.default, + torch.ops.aten.copy_.default, +) + class QuantizedTensorStorage: r"""Base class for all *TensorStorage classes. @@ -35,7 +41,7 @@ class QuantizedTensorStorage: XTensorStorage should contain all data members needed to implement the functionality of the tensor, while XTensor should only implement the functionality needed - to behave like regular torch.Tensor (liek __torch_dispatch__).""" + to behave like regular torch.Tensor (like __torch_dispatch__).""" _quantizer: Optional[Quantizer] @@ -63,6 +69,12 @@ def update_usage( f"{self.__class__.__name__} class does not implement update_usage function" ) + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement get_usages function" + ) + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" raise NotImplementedError( @@ -128,6 +140,7 @@ def prepare_for_saving( t, t_obj = tensor.prepare_for_saving() tensor_list.extend(t) tensor_objects_list.append(t_obj) + return tensor_list, tensor_objects_list @@ -314,6 +327,13 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-a """Returns whether or not given tensor can be quantized""" return True + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the quantizer""" + return { + "rowwise": self.rowwise_usage, + "columnwise": self.columnwise_usage, + } + class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data @@ -325,7 +345,14 @@ class QuantizedTensor(torch.Tensor): """ - def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False): + def __new__( + cls, + shape: Iterable[int], + dtype: torch.dtype, + *, + requires_grad: bool = False, + device: Optional[torch.device] = None, + ): # We are assuming only contiguous tensors stride = _stride_from_shape(shape) instance = torch.Tensor._make_wrapper_subclass( @@ -336,7 +363,7 @@ def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: boo dtype=dtype, layout=torch.strided, requires_grad=requires_grad, - device=torch.cuda.current_device(), + device=torch.cuda.current_device() if device is None else device, ) return instance @@ -366,6 +393,9 @@ def detach(self) -> QuantizedTensor: def clear(self): """Deallocate this tensor's memory. Typically not needed and must be used carefully""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement clear function" + ) def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" @@ -407,6 +437,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.copy_.default: dst = args[0] src = args[1] + if ( + isinstance(dst, QuantizedTensor) + and isinstance(src, QuantizedTensor) + and type(dst._quantizer) is type(src._quantizer) + and set(src.get_usages().keys()) == set(dst.get_usages().keys()) + and all( + src.get_usages()[usage] == dst.get_usages()[usage] + for usage in src.get_usages().keys() + ) + ): + + dst_tensors, dst_tensor_obj = dst.prepare_for_saving() + src_tensors, src_tensor_obj = src.prepare_for_saving() + for dst_tensor, src_tensor in zip(dst_tensors, src_tensors): + if dst_tensor is not None: + dst_tensor.copy_(src_tensor, *args[2:], **kwargs) + dst_tensor_obj.restore_from_saved(dst_tensors) + src_tensor_obj.restore_from_saved(src_tensors) + return None + if isinstance(dst, QuantizedTensor): dst.quantize_(src) else: @@ -419,6 +469,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") + # Empty like op + if func == torch.ops.aten.empty_like.default: + tensor = args[0] + device = kwargs.get("device", tensor.device) + requires_grad = kwargs.get("requires_grad", tensor.requires_grad) + pin_memory = kwargs.get("pin_memory", False) + usage = tensor.get_usages() + quantizer_usage = tensor._quantizer.get_usages() + tensor._quantizer.set_usage(**usage) + out = tensor._quantizer.make_empty( + shape=tensor.shape, + dtype=tensor.dtype, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + tensor._quantizer.set_usage(**quantizer_usage) + return out + + if func == torch.ops.aten.numel.default: + tensor = args[0] + return math.prod(tensor.size()) + + if func == torch.ops.aten.is_pinned.default: + tensor = args[0] + for t in tensor.get_data_tensors(): + if t is not None: + return func(t) + return False # Or error out? + def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): return arg.dequantize(dtype=arg.dtype) @@ -463,6 +543,16 @@ def maybe_update_inplace(arg, new_arg, schema_arg): def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + + def check_if_cpu(arg): + if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu": + assert ( + func in _quantized_tensor_cpu_supported_ops + ), f"QuantizedTensor on CPU does not support this operation: {func}" + return arg + + args = tree_map(check_if_cpu, args) + # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 8054374c8..8440c14b7 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -214,6 +214,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -229,12 +230,13 @@ def make_empty( data = None scale_inv = None if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_shape = self.get_scale_shape(shape, columnwise=False) scale_inv = torch.empty( scale_shape, dtype=torch.float32, device=device, + pin_memory=pin_memory, ) # Allocate FP8 data transpose if needed @@ -242,13 +244,17 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty( - self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + self.get_columnwise_shape(shape), + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( columnwise_scale_shape, dtype=torch.float32, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index eb2ac9a58..7f7195a17 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -101,6 +101,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8Tensor: # Canonicalize tensor attributes @@ -108,16 +109,19 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - transpose_shape = [data.size(-1)] + list(data.shape[:-1]) + transpose_shape = [shape[-1]] + list(shape[:-1]) data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -125,7 +129,7 @@ def make_empty( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -287,6 +291,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8Tensor: # Canonicalize tensor attributes @@ -294,23 +299,26 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - transpose_shape = [data.size(-1)] + list(data.shape[:-1]) + transpose_shape = [shape[-1]] + list(shape[:-1]) data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor return Float8Tensor( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -715,14 +723,22 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return cls.detach(args[0]) if func == torch.ops.aten.clone.default: return cls.clone(args[0]) + if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) - if src._transpose is not None or dst._transpose is not None: - dst._create_transpose() + if dst._data is not None: + dst._data.copy_(src._data.detach(), *args[2:], **kwargs) + if dst._scale_inv is not None: + dst._scale_inv.copy_( + src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs + ) + if dst._transpose is not None and not dst._transpose_invalid: + if not src._transpose_invalid: + dst._transpose.copy_(src._transpose, *args[2:], **kwargs) + else: + dst._create_transpose() return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 15e0b86c9..7ca6e3b0d 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -90,6 +90,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> MXFP8Tensor: # Canonicalize tensor attributes @@ -105,24 +106,29 @@ def make_empty( ) # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_inv = torch.empty( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + scale_inv = torch.empty( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, + ) # Allocate FP8 data transpose if needed columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: - columnwise_data = torch.empty_like(data) + columnwise_data = torch.empty_like(data, pin_memory=pin_memory) columnwise_scale_inv = torch.empty( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -348,11 +354,17 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) if rowwise_matches and columnwise_matches: if dst._rowwise_data is not None: - dst._rowwise_data.copy_(src._rowwise_data.detach()) - dst._rowwise_scale_inv.copy_(src._rowwise_scale_inv.detach()) + dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs) + dst._rowwise_scale_inv.copy_( + src._rowwise_scale_inv.detach(), *args[2:], **kwargs + ) if dst._columnwise_data is not None: - dst._columnwise_data.copy_(src._columnwise_data.detach()) - dst._columnwise_scale_inv.copy_(src._columnwise_scale_inv.detach()) + dst._columnwise_data.copy_( + src._columnwise_data.detach(), *args[2:], **kwargs + ) + dst._columnwise_scale_inv.copy_( + src._columnwise_scale_inv.detach(), *args[2:], **kwargs + ) return dst # FSDP2 related functions. diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 7a5f8858f..31dbcf00a 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Iterable import math -from typing import Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import functools import torch @@ -265,6 +265,7 @@ def make_empty( *, dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, + pin_memory: bool = False, requires_grad: bool = False, ) -> NVFP4Tensor: @@ -288,11 +289,18 @@ def make_empty( scale_inv = None amax_rowwise = None if self.rowwise_usage: - data = torch.empty(self.convert_shape_for_fp4(shape), dtype=torch.uint8, device=device) + data = torch.empty( + self.convert_shape_for_fp4(shape), + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, + ) scale_shape = self.get_scale_shape(shape, columnwise=False) - scale_inv = torch.empty(scale_shape, dtype=torch.uint8, device=device) + scale_inv = torch.empty( + scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) # Allocate per tensor scale inverse. FP32 format. - amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device) + amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory) # Allocate FP8 data transpose if needed columnwise_data = None @@ -306,12 +314,15 @@ def make_empty( self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)), dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device + columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) + amax_columnwise = torch.zeros( + 1, dtype=torch.float32, device=device, pin_memory=pin_memory ) - amax_columnwise = torch.zeros(1, dtype=torch.float32, device=device) # Construct FP8 tensor return NVFP4Tensor( @@ -498,6 +509,12 @@ def contiguous( return self raise ValueError("NVFP4Tensor does not support different memory formats!") + def get_usages(self) -> Dict[str, bool]: + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -520,16 +537,20 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) if tensor._rowwise_data is not None: - rowwise_data = data_init_func(tensor._rowwise_data) - rowwise_scale_inv = scale_inv_init_func(tensor._rowwise_scale_inv) - amax_rowwise = torch.zeros_like(tensor._amax_rowwise) + rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs) + rowwise_scale_inv = scale_inv_init_func( + tensor._rowwise_scale_inv, *args[1:], **kwargs + ) + amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs) else: rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None if tensor._columnwise_data is not None: - columnwise_data = data_init_func(tensor._columnwise_data) - columnwise_scale_inv = scale_inv_init_func(tensor._columnwise_scale_inv) - amax_columnwise = torch.zeros_like(tensor._amax_columnwise) + columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs) + columnwise_scale_inv = scale_inv_init_func( + tensor._columnwise_scale_inv, *args[1:], **kwargs + ) + amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs) else: columnwise_data, columnwise_scale_inv, amax_columnwise = ( None, diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index c2d5e8b3f..38d117b2a 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -420,3 +420,10 @@ def update_usage( return return + + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index a31f6a379..8d12c3070 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -225,3 +225,12 @@ def update_usage( if not needs_data_transpose: self._transpose = None self._transpose_invalid = True + + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + usages = {"rowwise": self._data is not None} + if is_non_tn_fp8_gemm_supported(): + usages["columnwise"] = self._data is not None + else: + usages["columnwise"] = self._transpose is not None and not self._transpose_invalid + return usages diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 2cca0829d..e7840d2c4 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -254,3 +254,10 @@ def update_usage( else: self._columnwise_data = None self._columnwise_scale_inv = None + + def get_usages(self) -> Tuple[bool, bool]: + """Get the usage of the tensor""" + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } From 389a6ba4f7be41e21b3db437d9ba23a01a44db1a Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:57:07 -0800 Subject: [PATCH 128/141] [JAX] Use TE quant if TE fused act is disabled (#2374) * Use TE quant if TE fused act is disabled Signed-off-by: Jeremy Berchtold * Keep existing precision Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/activation.py | 43 ++++++++++++++++--- 1 file changed, 36 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index aa84fafd3..e8249de17 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -27,7 +27,7 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_dbias, _quantize_dbias_impl, AmaxScope +from .quantization import _jax_dbias, quantize, quantize_dbias, _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -1268,7 +1268,19 @@ def act_lu( ) act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer, act_params) + act_out = _jax_act_lu(x, activation_type, act_params=act_params) + assert ( + act_out.data.dtype == x.dtype + ), f"JAX activation output dtype {act_out.data.dtype} must match input dtype {x.dtype}" + if quantizer is None: + return act_out + + return quantize( + act_out, + quantizer=quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) # TE/common does not support colwise-only quantization yet if quantizer is not None and quantizer.q_layout.is_colwise_only: @@ -1330,11 +1342,12 @@ def act_lu( transpose_batch_sequence=transpose_batch_sequence, output_amax_when_no_scaling=True, ) - out, _ = _quantize_dbias_impl( + assert ( + out.data.dtype == x.dtype + ), f"Activation output dtype {out.data.dtype} must match input dtype {x.dtype}" + out = quantize( out, - is_dbias=False, quantizer=quantizer, - dq_dtype=x.dtype, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence, ) @@ -1419,7 +1432,23 @@ def quantize_dact_dbias( if not PrimitiveClass.enabled() or ( quantizer is not None and quantizer.q_layout.is_colwise_only ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer, act_params) + if quantizer is None: + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, act_params=act_params) + dact_out, _ = _jax_quantize_dact_dbias( + dz, x, activation_type, is_dbias=False, act_params=act_params + ) + assert ( + dact_out.data.dtype == x.dtype + ), f"JAX dact output dtype {dact_out.data.dtype} must match input dtype {x.dtype}" + return quantize_dbias( + dact_out, + quantizer, + is_dbias=is_dbias, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + if quantizer is None: output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind( dz, @@ -1465,7 +1494,7 @@ def quantize_dact_dbias( output_amax_when_no_scaling=output_amax_when_no_scaling, ) return _quantize_dbias_impl( - out.data, + out, quantizer, is_dbias=True, dq_dtype=x.dtype, From b8a402495948d87c55355ef2f8ee4a912dccde98 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Mon, 2 Feb 2026 14:16:13 -0600 Subject: [PATCH 129/141] [ROCm] resolve the conflicts in common dir --- hipify_custom_map.json | 3 +- transformer_engine/common/CMakeLists.txt | 161 +- transformer_engine/common/__init__.py | 63 +- transformer_engine/common/cast/cast.cu | 2 + .../common/cast/core/common.cuh | 2 + .../common/cast/dispatch/dequantize.cuh | 8 + .../common/cast/dispatch/gated.cuh | 18 + .../common/cast/dispatch/quantize.cuh | 8 + .../common/cast/fp8/dequantize_fp8.cuh | 2 + .../common/cast/fp8/gated_fp8.cuh | 4 + .../common/cast/fp8/quantize_fp8.cuh | 159 ++ .../common/cast/mxfp8/dequantize_mxfp8.cuh | 55 +- .../common/cast/mxfp8/gated_mxfp8.cuh | 563 +----- .../common/cast/mxfp8/quantize_mxfp8.cuh | 40 +- .../mxfp8/rocm_dequantize_mxfp8.cuh} | 26 +- .../mxfp8/rocm_gated_mxfp8.cuh} | 77 +- .../mxfp8/rocm_quantize_mxfp8.cuh} | 206 +-- transformer_engine/common/common.cu | 2 +- transformer_engine/common/common.h | 29 +- .../common/fused_attn_rocm/fused_attn.cpp | 124 +- .../fused_attn_rocm/fused_attn_aotriton.cpp | 6 +- .../fused_attn_rocm/fused_attn_aotriton.h | 1 + .../common/fused_attn_rocm/fused_attn_ck.cpp | 9 + .../common/fused_attn_rocm/fused_attn_ck.h | 1 + .../common/gemm/cublaslt_gemm.cu | 56 +- transformer_engine/common/gemm/rocm_gemm.cu | 5 +- .../include/transformer_engine/fused_attn.h | 10 +- .../common/normalization/common.h | 17 +- .../common/normalization/layernorm/ln_api.cpp | 4 - .../normalization/rmsnorm/rmsnorm_api.cpp | 4 - transformer_engine/common/recipe/__init__.py | 22 +- .../common/recipe/current_scaling.cu | 17 +- transformer_engine/common/swizzle/swizzle.cu | 131 -- .../common/util/cast_kernels.cuh | 1546 ----------------- transformer_engine/common/util/logging.h | 5 +- transformer_engine/common/util/ptx.cuh | 108 +- .../common/util/rocm_vectorized_2d.cuh | 68 - transformer_engine/common/utils.cuh | 3 - 38 files changed, 597 insertions(+), 2968 deletions(-) rename transformer_engine/common/{util/rocm_dequantize_kernels.cuh => cast/mxfp8/rocm_dequantize_mxfp8.cuh} (89%) rename transformer_engine/common/{util/rocm_cast_gated_kernels.cuh => cast/mxfp8/rocm_gated_mxfp8.cuh} (87%) rename transformer_engine/common/{util/rocm_cast_kernels.cuh => cast/mxfp8/rocm_quantize_mxfp8.cuh} (66%) delete mode 100644 transformer_engine/common/util/cast_kernels.cuh diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 97824bbdb..812ea384d 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -6,7 +6,8 @@ "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", "" : "", - "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)" + "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", + "__nv_bfloat162":"__hip_bfloat162" } } diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 399137484..46eb5dba5 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -33,20 +33,8 @@ else() endif() # Language options -<<<<<<< HEAD if(USE_CUDA) # Removed indent to minimize code diff with NV upstream -if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) - elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) - else () - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) - endif() -endif() -======= ->>>>>>> 389a6b set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) @@ -180,31 +168,21 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) set(transformer_engine_SOURCES) -<<<<<<< HEAD -# Source files in both cuda and rocm -list(APPEND transformer_engine_SOURCES -======= set(transformer_engine_cpp_sources) set(transformer_engine_cuda_sources) set(transformer_engine_cuda_arch_specific_sources) +# Source files in both cuda and rocm list(APPEND transformer_engine_cpp_sources - cudnn_utils.cpp ->>>>>>> 389a6b transformer_engine.cpp - fused_attn/fused_attn.cpp gemm/config.cpp normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/rmsnorm/rmsnorm_api.cpp util/cuda_driver.cpp - util/cuda_nvml.cpp util/cuda_runtime.cpp util/multi_stream.cpp - util/rtc.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/comm_gemm_overlap.cpp) + util/rtc.cpp) list(APPEND transformer_engine_cuda_sources common.cu @@ -218,43 +196,18 @@ list(APPEND transformer_engine_cuda_sources transpose/cast_transpose_fusion.cu transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu -<<<<<<< HEAD -======= - transpose/quantize_transpose_vector_blockwise.cu ->>>>>>> 389a6b transpose/swap_first_dims.cu dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu -<<<<<<< HEAD - activation/relu.cu - activation/swiglu.cu gemm/cublaslt_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp -======= - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - fused_attn/fused_attn_fp8.cu - fused_attn/utils.cu - gemm/cublaslt_gemm.cu ->>>>>>> 389a6b normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu util/padding.cu -<<<<<<< HEAD - util/cuda_driver.cpp - util/cuda_runtime.cpp - util/multi_stream.cpp - util/rtc.cpp -======= ->>>>>>> 389a6b - swizzle/swizzle.cu - swizzle/swizzle_block_scaling.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -264,41 +217,51 @@ list(APPEND transformer_engine_cuda_sources fused_router/fused_topk_with_score_function.cu recipe/current_scaling.cu recipe/delayed_scaling.cu -<<<<<<< HEAD recipe/fp8_block_scaling.cu) -if(USE_CUDA) -# Removed indent to minimize code diff with NV upstream -# Files unique in cuda building -list(APPEND transformer_engine_SOURCES - cudnn_utils.cpp - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise.cu - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - fused_attn/fused_attn_fp8.cu - fused_attn/fused_attn.cpp - fused_attn/utils.cu - gemm/cutlass_grouped_gemm.cu - util/cuda_nvml.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) -======= - recipe/fp8_block_scaling.cu - recipe/nvfp4.cu - comm_gemm_overlap/userbuffers/userbuffers.cu) list(APPEND transformer_engine_cuda_arch_specific_sources - gemm/cutlass_grouped_gemm.cu cast/cast.cu activation/gelu.cu activation/relu.cu - activation/swiglu.cu - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise_fp4.cu - hadamard_transform/hadamard_transform.cu - hadamard_transform/hadamard_transform_cast_fusion.cu) + activation/swiglu.cu) + +if(USE_CUDA) +#NV specific source codes + list(APPEND transformer_engine_cpp_sources + cudnn_utils.cpp + fused_attn/fused_attn.cpp + util/cuda_nvml.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/comm_gemm_overlap.cpp) + list(APPEND transformer_engine_cuda_sources + transpose/quantize_transpose_vector_blockwise.cu + fused_attn/fused_attn_f16_max512_seqlen.cu + fused_attn/fused_attn_f16_arbitrary_seqlen.cu + fused_attn/fused_attn_fp8.cu + fused_attn/utils.cu + swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu + recipe/nvfp4.cu + comm_gemm_overlap/userbuffers/userbuffers.cu) + list(APPEND transformer_engine_cuda_arch_specific_sources + gemm/cutlass_grouped_gemm.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu) +else() +#ROCm specific source codes + list(APPEND transformer_engine_cpp_sources + fused_attn_rocm/fused_attn.cpp + gemm/rocm_gemm.cu + amd_detail/system.cpp) + list(APPEND transformer_engine_cuda_sources + fused_attn_rocm/fused_attn_aotriton.cpp + fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/utils.cpp) +endif() + # Compiling the files with the worst compilation time first to hopefully overlap # better with the faster-compiling cpp files @@ -306,6 +269,7 @@ list(APPEND transformer_engine_SOURCES ${transformer_engine_cuda_arch_specific_s ${transformer_engine_cuda_sources} ${transformer_engine_cpp_sources}) +if(USE_CUDA) # Set compile options for CUDA sources with generic architectures foreach(cuda_source IN LISTS transformer_engine_cuda_sources) set(arch_compile_options) @@ -339,7 +303,6 @@ foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) ) endif() endforeach() ->>>>>>> 389a6b if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES @@ -347,14 +310,8 @@ list(APPEND transformer_engine_SOURCES endif() add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) + else() - list(APPEND transformer_engine_SOURCES - fused_attn_rocm/fused_attn.cpp - fused_attn_rocm/fused_attn_aotriton.cpp - fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp - gemm/rocm_gemm.cu - amd_detail/system.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) @@ -386,32 +343,20 @@ else() message(STATUS "nvte hipified sources: ${te_hip_sources}") add_library(transformer_engine SHARED ${te_hip_sources}) - target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}") + target_include_directories(transformer_engine PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) endif() target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") -<<<<<<< HEAD if (USE_CUDA) -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() -endif() #USE_CUDA -======= # CUTLASS kernels require SM90a and cause hang in debug build set_property( SOURCE gemm/cutlass_grouped_gemm.cu APPEND PROPERTY COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") ->>>>>>> 389a6b +endif() #USE_CUDA # Configure dependencies if (USE_CUDA) @@ -567,22 +512,7 @@ target_include_directories(transformer_engine PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/string_headers") # Compiler options -<<<<<<< HEAD -set_source_files_properties(fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") if(USE_CUDA) -======= set(nvte_sources_with_fast_math) list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -596,7 +526,6 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu) ->>>>>>> 389a6b option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index f8b302d49..cdda37508 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -17,16 +17,9 @@ import subprocess import sys import sysconfig -<<<<<<< HEAD -from typing import Optional - -import transformer_engine - -_logger = logging.getLogger(__name__) -======= from typing import Optional, Tuple ->>>>>>> 389a6b +import transformer_engine @functools.lru_cache(maxsize=None) def _is_package_installed(package) -> bool: @@ -145,8 +138,10 @@ def get_te_core_package_info() -> Tuple[bool, str, str]: Check if Tranformer Engine core package is installed. Returns the module name and version if found. """ - + te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") + if te_rocm_build: + te_core_packages = ("transformer-engine-rocm") for package in te_core_packages: if _is_package_installed(package): return True, package, version(package) @@ -171,42 +166,6 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" -<<<<<<< HEAD - te_cuda_vers = "rocm" if te_rocm_build else "cu12" - - # If the framework extension pip package is installed, it means that TE is installed via - # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework - # extension are all installed via PyPI and have matching version. - if _is_pip_package_installed(module_name): - assert _is_pip_package_installed( - "transformer_engine" - ), "Could not find `transformer-engine`." - assert _is_pip_package_installed( - f"transformer_engine_{te_cuda_vers}" - ), f"Could not find `transformer-engine-{te_cuda_vers}`." - assert ( - version(module_name) - == version("transformer-engine") - == version(f"transformer-engine-{te_cuda_vers}") - ), ( - "TransformerEngine package version mismatch. Found" - f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}" - f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" - ) - - # If the core package is installed via PyPI, log if - # the framework extension is not found from PyPI. - # Note: Should we error? This is a rare use case. - if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"): - if not _is_pip_package_installed(module_name): - _logger.info( - "Could not find package %s. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'", - module_name, - ) -======= # Find the TE packages. The core and framework packages can only be installed via PyPI. # For the `transformer-engine` package, we need to check explicity. te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() @@ -230,7 +189,6 @@ def load_framework_extension(framework: str) -> None: f" v{te_core_version}. Install transformer-engine using " f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'" ) ->>>>>>> 389a6b # After all checks are completed, load the shared object file. spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) @@ -438,7 +396,6 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): -<<<<<<< HEAD try: _CUDNN_LIB_CTYPES = _load_cudnn() _NVRTC_LIB_CTYPES = _load_nvrtc() @@ -446,9 +403,6 @@ def _load_core_library(): _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") - # Needed to find the correct headers for NVRTC kernels. - if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir(): - os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir() except (OSError, subprocess.CalledProcessError): pass finally: @@ -473,13 +427,4 @@ def _load_core_library(): assert (rocm_version == build_rocm_version), f"ROCm {'.'.join(rocm_version)} is detected but the library is built for {'.'.join(build_rocm_version)}" except FileNotFoundError: pass -======= - sanity_checks_for_pypi_installation() - _CUDNN_LIB_CTYPES = _load_cudnn() - _NVRTC_LIB_CTYPES = _load_nvrtc() - _CURAND_LIB_CTYPES = _load_curand() - _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") - _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") - _TE_LIB_CTYPES = _load_core_library() ->>>>>>> 389a6b diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 1ed46a335..575106a53 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -5,7 +5,9 @@ ************************************************************************/ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include #include diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index b750142f5..ec36e941f 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index b8547915c..f55719852 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -16,7 +16,9 @@ #include "../../common.h" #include "../fp8/dequantize_fp8.cuh" #include "../mxfp8/dequantize_mxfp8.cuh" +#ifndef __HIP_PLATFORM_AMD__ #include "../nvfp4/dequantize_nvfp4.cuh" +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace dispatch { @@ -34,17 +36,23 @@ inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t break; } case NVTE_MXFP8_1D_SCALING: { +#ifndef __HIP_PLATFORM_AMD__ if (is_supported_by_CC_100()) { +#endif //#ifndef __HIP_PLATFORM_AMD__ mxfp8::dequantize(input, output, stream); +#ifndef __HIP_PLATFORM_AMD__ } else { NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); } +#endif //#ifndef __HIP_PLATFORM_AMD__ break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { nvfp4::dequantize(input, output, stream); break; } +#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); } diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh index 4373090b7..8f236023b 100644 --- a/transformer_engine/common/cast/dispatch/gated.cuh +++ b/transformer_engine/common/cast/dispatch/gated.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -45,6 +47,9 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { +#ifdef __HIP_PLATFORM_AMD__ + fp8::cast_gated_fwd(input, output, p, stream); +#else const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); if (use_tma_kernels) { Tensor dummy_grad_tensor; @@ -53,6 +58,7 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp } else { fp8::cast_gated_fwd(input, output, p, stream); } +#endif //#ifdef __HIP_PLATFORM_AMD__ break; } case NVTE_MXFP8_1D_SCALING: { @@ -68,8 +74,12 @@ void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_outp NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } +#ifdef __HIP_PLATFORM_AMD__ + //TODO: add gfx950 equivalent checking +#else NVTE_CHECK(is_supported_by_CC_100(), "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); +#endif Tensor dummy_grad_tensor; mxfp8::quantize_gated(input, dummy_grad_tensor, output, p, stream); @@ -122,6 +132,9 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { +#ifdef __HIP_PLATFORM_AMD__ + fp8::cast_gated_bwd(gated_input, grad, output, p, stream); +#else const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100(); if (use_tma_kernels) { fp8::cast_gated_tma(gated_input, grad, output, p, @@ -129,6 +142,7 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte } else { fp8::cast_gated_bwd(gated_input, grad, output, p, stream); } +#endif //#ifdef __HIP_PLATFORM_AMD__ break; } case NVTE_MXFP8_1D_SCALING: { @@ -144,8 +158,12 @@ void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype), "The type of the columnwise output tensor should be FP8."); } +#ifdef __HIP_PLATFORM_AMD__ + // add gfx950 equivalent check +#else NVTE_CHECK(is_supported_by_CC_100(), "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+"); +#endif //#ifdef __HIP_PLATFORM_AMD__ mxfp8::quantize_gated(gated_input, grad, output, p, stream); diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 9f7a4a9b0..8e8993668 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -19,8 +21,10 @@ #include "../core/common.cuh" #include "../fp8/quantize_fp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" +#ifndef __HIP_PLATFORM_AMD__ #include "../nvfp4/quantize_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace dispatch { @@ -87,6 +91,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, dummy_workspace_tensor, stream); break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING"); @@ -167,6 +172,7 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } +#endif//#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } @@ -232,6 +238,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens stream); break; } +#ifndef __HIP_PLATFORM_AMD__ case NVTE_NVFP4_1D_SCALING: { NVTE_CHECK((!IS_DBIAS && !IS_DACT), "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING"); @@ -315,6 +322,7 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens columnwise_option, force_pow_2_scales, noop_tensor->data, stream); break; } +#endif //#ifndef __HIP_PLATFORM_AMD__ default: NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); } diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 2514758b5..5d30a6c3f 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif // #ifndef __HIP_PLATFORM_AMD__ #include #include diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index 225ef93ed..c9040a3da 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -12,7 +12,9 @@ #define TRANSFORMER_ENGINE_GATED_FP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -25,6 +27,7 @@ namespace transformer_engine { namespace dispatch { namespace fp8 { +#ifndef __HIP_PLATFORM_AMD__ namespace kernel { constexpr size_t CHUNK_DIM_Y = 128; @@ -348,6 +351,7 @@ void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *outpu NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) } +#endif //#ifndef __HIP_PLATFORM_AMD__ template void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index efc5015b7..9de093e96 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -12,7 +14,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include #include @@ -35,6 +39,58 @@ namespace transformer_engine { namespace dispatch { namespace fp8 { +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t TILE_DIM = 32; +template +__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { + __shared__ float tile[TILE_DIM][TILE_DIM]; + + int tile_start_col = blockIdx.x * TILE_DIM; + int tile_start_row = blockIdx.y * TILE_DIM; + int thread_col_in_tile = threadIdx.x; + int thread_row_in_tile = threadIdx.y; + + int global_col = tile_start_col + thread_col_in_tile; + int global_row = tile_start_row + thread_row_in_tile; + + if (global_row < rows && global_col < cols) { + tile[thread_row_in_tile][thread_col_in_tile] = static_cast(input[global_row * cols + global_col]); + } else { + tile[thread_row_in_tile][thread_col_in_tile] = 0.0f; + } + __syncthreads(); + + for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) { + if (thread_row_in_tile < stride) { + tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; + } + __syncthreads(); + } + + if (thread_row_in_tile == 0 && global_col < cols) { + partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile]; + } +} + +template +void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, + const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) { + dim3 block_dim_partial(TILE_DIM, TILE_DIM); + dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM)); + + const size_t partial_rows = grid_dim_partial.y; + float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); + + partial_reduce_kernel<<>>( + workspace_ptr, + partial_workspace, + rows, cols); + + common::reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); +} + + +#else namespace quantize_2D_kernel { constexpr size_t FP8_CHUNK_DIM_Y = 128; @@ -454,16 +510,33 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T }); // NOLINT(*) ); // NOLINT(*) } +#endif //#ifdef __HIP_PLATFORM_AMD__ namespace detail { using Empty = transformer_engine::Empty; __device__ inline float identity(float value, const Empty &) { return value; } } // namespace detail +/* HIPCC has strict rules for __device__ functions usage on host. + It forbids not only calling but also other ODR-use assigning to variables + https://github.com/llvm/llvm-project/issues/105825 + Use templated struct wrapper to work around + */ +template +struct ActivationType +{ + static constexpr auto op = OP; +}; + + template void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; +#else //#ifdef __HIP_PLATFORM_AMD__ constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; +#endif //#ifdef __HIP_PLATFORM_AMD__ const size_t N = product(input.data.shape); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, @@ -487,7 +560,11 @@ void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, template void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; +#else //#ifdef __HIP_PLATFORM_AMD__ constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; +#endif //#ifdef __HIP_PLATFORM_AMD__ const size_t N = product(input->data.shape); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input->data.dtype, IType, @@ -512,7 +589,9 @@ template void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { +#ifndef __HIP_PLATFORM_AMD__ using namespace quantize_1D_kernel; +#endif //#ifndef __HIP_PLATFORM_AMD__ CheckNoopTensor(*noop, "cast_noop"); CheckInputTensor(input, "cast_input"); CheckOutputTensor(*output, "cast_output"); @@ -531,6 +610,85 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); +#ifdef __HIP_PLATFORM_AMD__ + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true."); + if (workspace->data.dptr == nullptr) { + if constexpr (IS_DACT) { + const size_t partial_rows = DIVUP(rows, TILE_DIM); + size_t total_elements = (rows * cols) + (partial_rows * cols); + workspace->data.shape = {total_elements}; + workspace->data.dtype = DType::kFloat32; + } else { + workspace->data.shape = {rows, cols}; + workspace->data.dtype = DType::kFloat32; + } + return; + } + + const void *ptr_to_reduce = nullptr; + DType dtype_to_reduce; + + workspace->amax = {}; + workspace->scale = {}; + workspace->scale_inv = {}; + + Tensor workspace_buffer; + Tensor partial_sum_buffer; + + if constexpr (IS_DACT) { + // The values to reduce are the result of the dAct function. + NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); + + const size_t partial_rows = DIVUP(rows, TILE_DIM); + const size_t full_size_bytes = rows * cols * sizeof(float); + workspace_buffer = *workspace; + workspace_buffer.data.shape = {rows, cols}; + partial_sum_buffer.data.dptr = reinterpret_cast(workspace->data.dptr) + full_size_bytes; + partial_sum_buffer.data.shape = {partial_rows, cols}; + partial_sum_buffer.data.dtype = DType::kFloat32; + workspace = &partial_sum_buffer; + + CastVectorizedUnaryGradKernelLauncher(input, act_input, &workspace_buffer, stream); + if (output && output->data.dptr) { + CastVectorizedUnaryKernelLauncher(workspace_buffer, noop, output, stream); + } + ptr_to_reduce = workspace_buffer.data.dptr; + dtype_to_reduce = workspace_buffer.data.dtype; + } else { + if (output && output->data.dptr) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + // The values to reduce are just the input values. + ptr_to_reduce = input.data.dptr; + dtype_to_reduce = input.data.dtype; + } + + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias tensor."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dbias->data.dtype, DBiasTypeOut, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dtype_to_reduce, DTypeReduce, + reduce_dbias_rocm( + reinterpret_cast(ptr_to_reduce), + dbias, rows, cols, stream, workspace); + ); + ); + } else { + if (output && output->data.dptr) { + if constexpr (IS_DACT) { + NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output."); + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } else { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } + } +#else // Supported by the Arch >= 10.0 if (is_supported_by_CC_100()) { if (!IS_DBIAS && !IS_DACT) { @@ -571,6 +729,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); } } +#endif //#ifdef __HIP_PLATFORM_AMD__ } } // namespace fp8 diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 89391b21f..96aed3e88 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -20,34 +20,18 @@ #include #include -<<<<<<< HEAD:transformer_engine/common/util/dequantize_kernels.cuh -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transpose.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_dequantize_kernels.cuh" -#endif -======= #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" #include "../../utils.cuh" ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh namespace transformer_engine { namespace dispatch { namespace mxfp8 { namespace dequantize_kernel { - -#ifndef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_dequantize_mxfp8.cuh" +#else template __global__ void __launch_bounds__(THREADS_PER_CHUNK) dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, @@ -225,11 +209,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -<<<<<<< HEAD:transformer_engine/common/util/dequantize_kernels.cuh #endif // #ifndef __HIP_PLATFORM_AMD__ -======= } // namespace dequantize_kernel ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { using namespace dequantize_kernel; @@ -328,39 +309,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) #endif NVTE_CHECK_CUDA(cudaGetLastError()); } -<<<<<<< HEAD:transformer_engine/common/util/dequantize_kernels.cuh -} // namespace dequantization - -namespace detail { - -void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - if (1) { -#else - if (is_supported_by_CC_100()) { -#endif - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } -} - -} // namespace detail -======= } // namespace mxfp8 } // namespace dispatch ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index a59e85659..28e46fc7a 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -20,258 +20,6 @@ #include #include -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh -#include - -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_cast_gated_kernels.cuh" -#endif - -namespace transformer_engine { - -namespace gated_kernels { - -#ifndef __HIP_PLATFORM_AMD__ -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 512; -constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 -constexpr size_t BUFFERS_NUM = 2; -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); - -__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act, - const __grid_constant__ CUtensorMap tensor_map_output_gate, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0; - - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - - constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t in_transaction_size = buff_elems * sizeof(IType); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); - const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - // Prefetch data of the first stage - - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, - TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, - chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } else { - copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } - -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const size_t buff = it % BUFFERS_NUM; - const size_t next_it = it + 1; - if (next_it < ITERATIONS) { - const size_t next_buff = next_it % BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3( - &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, - &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, - in_transaction_size, &mbar[next_it], is_master_thread); - } else { - copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, - chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, - &mbar[next_it], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_sh_curr = out_act_sh + buff * buff_elems; - OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; - -#pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); - } - - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; - - out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; - out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // dGeLU - ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, - chunk_it_offset_y, - reinterpret_cast(out_act_sh_curr)); - - if constexpr (IS_DGATED) { - // dGate - ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_sh_curr)); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -namespace mxfp8_kernel { -======= #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" @@ -281,8 +29,9 @@ namespace transformer_engine { namespace dispatch { namespace mxfp8 { namespace gated_kernel { ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh - +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_gated_mxfp8.cuh" +#else constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; @@ -925,99 +674,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace gated_kernel template void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, cudaStream_t stream) { -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh - checkCuDriverContext(stream); - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act{}; - alignas(64) CUtensorMap tensor_map_output_gate{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} -#endif //#ifdef __HIP_PLATFORM_AMD__ - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { -======= using namespace gated_kernel; ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh checkCuDriverContext(stream); const bool USE_ROWWISE_SCALING = output->has_data(); @@ -1045,26 +709,18 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_BWD ? 2 : 1) * cols; -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh #ifdef __HIP_PLATFORM_AMD__ constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; constexpr size_t BUFFS_NUM = BUFFERS_NUM; +#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); -#else - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; -======= - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +#ifndef __HIP_PLATFORM_AMD__ const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; @@ -1087,7 +743,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *tensor_map_grad = IS_BWD ? reinterpret_cast(grad.data.dptr) : nullptr; const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; @@ -1153,15 +809,11 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh #ifdef __HIP_PLATFORM_AMD__ const size_t out_gate_mem = buff_size_aligned_out; #else - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); -#endif -======= const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +#endif size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } @@ -1175,18 +827,18 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, + quantize_gated_mxfp8_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - cast_mxfp8_gated_kernel + quantize_gated_mxfp8_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); }))); // NOLINT(*) #else @@ -1232,200 +884,15 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; -<<<<<<< HEAD:transformer_engine/common/util/cast_gated_kernels.cuh - } + } + } NVTE_CHECK_CUDA(cudaGetLastError()); // NOLINT(*) #endif ); // NOLINT(*) ); // NOLINT(*) } -template -void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, - "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", grad.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, - "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match. Input shape: ", input.data.shape, - ", output shape: ", output->data.shape, "."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { - constexpr bool allow_empty = false; - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", allow_empty); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - if constexpr (IS_DGATED) { - CheckInputTensor(grad, "grad"); - NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); - NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); - } - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - bool is_fp8_rowwise_output = true; - bool is_fp8_colwise_output = true; - if (output->has_data()) { - is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - if (output->has_columnwise_data()) { - is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - - const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; - - if (is_delayed_tensor_scaling(output->scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } -#else - if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } - } -#endif - } else if (is_mxfp_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); - } else { - NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } - } else { - NVTE_ERROR("Not supported scaling mode"); - } -} -} // namespace gated_kernels - -namespace detail { - -template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - cudaStream_t stream) { - using namespace gated_kernels; - Tensor grad_empty_tensor; - const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); - Tensor *output_tensor = convertNVTETensorCheck(output); - -#ifdef __HIP_PLATFORM_AMD__ - if (1) { -#else - if (is_supported_by_CC_100()) { -#endif - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); - } else { - cast_gated(gated_input_tensor, output_tensor, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } -} -} // namespace detail - -======= - } - } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} - } // namespace mxfp8 } // namespace dispatch ->>>>>>> 389a6b:transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 5505de605..19234e9b4 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -12,7 +14,9 @@ #define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ #include +#ifndef __HIP_PLATFORM_AMD__ #include +#endif //#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -26,7 +30,9 @@ namespace transformer_engine { namespace dispatch { namespace mxfp8 { namespace quantize_kernel { - +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_quantize_mxfp8.cuh" +#else constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -536,6 +542,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // namespace quantize_kernel template has_data(); bool use_colwise_scaling = output->has_columnwise_data(); @@ -562,6 +571,11 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; + constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; + constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; +#else constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; @@ -572,6 +586,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); @@ -589,6 +604,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; +#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { scaling_type = ScalingType::ROWWISE; @@ -597,6 +613,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } else if (use_rowwise_scaling && use_colwise_scaling) { scaling_type = ScalingType::BIDIMENSIONAL; } +#endif if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); @@ -619,6 +636,26 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, +#ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + quantize_mxfp8_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))); // NOLINT(*) +#else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_output_rowwise{}; @@ -708,6 +745,7 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, break; } } +#endif // #ifdef __HIP_PLATFORM_AMD__ if constexpr (IS_DBIAS) { common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh similarity index 89% rename from transformer_engine/common/util/rocm_dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index 398e4c0ad..02224a69f 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -5,26 +5,7 @@ ************************************************************************/ #pragma once - -#include -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/cast.h" -#include "transpose/cast_transpose.h" -#include "transformer_engine/transpose.h" -#include "utils.cuh" -#include "vectorized_pointwise.h" - -namespace transformer_engine { - -namespace dequantization { +// drop-in rocm replacement for mxfp8 dequantize kernel constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -127,12 +108,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); - bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } } -} // namespace dequantization -} // namespace transformer_engine + diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh similarity index 87% rename from transformer_engine/common/util/rocm_cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index a53fd51c5..7382b8aab 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -5,22 +5,7 @@ ************************************************************************/ #pragma once - -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/cast.h" -#include "vectorized_pointwise.h" -#include "utils.cuh" - -namespace transformer_engine { -namespace gated_kernels { +// drop-in rocm replacement for mxfp8 gated quantize kernel constexpr size_t ALIGNMENT_SIZE = 128; // TODO: Identify optimal chunk/thread size for MI350+ @@ -45,16 +30,17 @@ template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const IType *grad_ptr, - const IType *input_act, - const IType *input_gate, - OType *output_act_rowwise, - OType *output_gate_rowwise, - OType *output_act_colwise, - OType *output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_gated_mxfp8_kernel( + const IType *grad_ptr, + const IType *input_act, + const IType *input_gate, + OType *output_act_rowwise, + OType *output_gate_rowwise, + OType *output_act_colwise, + OType *output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; @@ -171,24 +157,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh[shmem_idx]); float gate_elt = static_cast(in_gate_sh[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_idx]); const float x = act_elt; float act_x; float dact_x; - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; + after_dgate_reg[stage] = dgate_elt ? act_x * grad_elt : 0.0f; } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + after_dact_reg[stage] = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 @@ -355,24 +356,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } __syncthreads(); } } -} // namespace gated_kernels -} // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh similarity index 66% rename from transformer_engine/common/util/rocm_cast_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index e39e0a4a7..dc36fb42d 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -4,28 +4,7 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once - -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/cast.h" -#include "transpose/cast_transpose.h" -#include "vectorized_pointwise.h" -#include "utils.cuh" - -namespace transformer_engine { - -// Forward declaration, definition is in cast_kernels.cuh -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream); - +// drop-in replacement for rocm quantize_mxfp8 kernels constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; @@ -53,14 +32,15 @@ template __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const IType *input_ptr, - const IType *act_input_ptr, - OType *output_rowwise, - OType *output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_mxfp8_kernel( + const IType *input_ptr, + const IType *act_input_ptr, + OType *output_rowwise, + OType *output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if (noop != nullptr && noop[0] == 1.0f) return; } @@ -310,12 +290,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, + ptx::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } @@ -393,165 +373,3 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) atomicMaxFloat(amax_ptr, block_amax); } } - -// Forward declaration of functions defined in `cast_kernels.cuh` -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream); - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream); - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream); - -constexpr size_t TILE_DIM = 32; -template -__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { - __shared__ float tile[TILE_DIM][TILE_DIM]; - - int tile_start_col = blockIdx.x * TILE_DIM; - int tile_start_row = blockIdx.y * TILE_DIM; - int thread_col_in_tile = threadIdx.x; - int thread_row_in_tile = threadIdx.y; - - int global_col = tile_start_col + thread_col_in_tile; - int global_row = tile_start_row + thread_row_in_tile; - - if (global_row < rows && global_col < cols) { - tile[thread_row_in_tile][thread_col_in_tile] = static_cast(input[global_row * cols + global_col]); - } else { - tile[thread_row_in_tile][thread_col_in_tile] = 0.0f; - } - __syncthreads(); - - for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) { - if (thread_row_in_tile < stride) { - tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; - } - __syncthreads(); - } - - if (thread_row_in_tile == 0 && global_col < cols) { - partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile]; - } -} - -template -void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, - const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) { - dim3 block_dim_partial(TILE_DIM, TILE_DIM); - dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM)); - - const size_t partial_rows = grid_dim_partial.y; - float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); - - partial_reduce_kernel<<>>( - workspace_ptr, - partial_workspace, - rows, cols); - - reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); -} - -template -void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true."); - NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true."); - if (workspace->data.dptr == nullptr) { - if constexpr (IS_DACT) { - const size_t partial_rows = DIVUP(rows, TILE_DIM); - size_t total_elements = (rows * cols) + (partial_rows * cols); - workspace->data.shape = {total_elements}; - workspace->data.dtype = DType::kFloat32; - } else { - workspace->data.shape = {rows, cols}; - workspace->data.dtype = DType::kFloat32; - } - return; - } - - const void *ptr_to_reduce = nullptr; - DType dtype_to_reduce; - - workspace->amax = {}; - workspace->scale = {}; - workspace->scale_inv = {}; - - Tensor workspace_buffer; - Tensor partial_sum_buffer; - - if constexpr (IS_DACT) { - // The values to reduce are the result of the dAct function. - NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); - - const size_t partial_rows = DIVUP(rows, TILE_DIM); - const size_t full_size_bytes = rows * cols * sizeof(float); - workspace_buffer = *workspace; - workspace_buffer.data.shape = {rows, cols}; - partial_sum_buffer.data.dptr = reinterpret_cast(workspace->data.dptr) + full_size_bytes; - partial_sum_buffer.data.shape = {partial_rows, cols}; - partial_sum_buffer.data.dtype = DType::kFloat32; - workspace = &partial_sum_buffer; - - CastVectorizedUnaryGradKernelLauncher(input, act_input, &workspace_buffer, stream); - if (output && output->data.dptr) { - CastVectorizedUnaryKernelLauncher(workspace_buffer, noop, output, stream); - } - ptr_to_reduce = workspace_buffer.data.dptr; - dtype_to_reduce = workspace_buffer.data.dtype; - } else { - if (output && output->data.dptr) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - // The values to reduce are just the input values. - ptr_to_reduce = input.data.dptr; - dtype_to_reduce = input.data.dtype; - } - - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias tensor."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dbias->data.dtype, DBiasTypeOut, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dtype_to_reduce, DTypeReduce, - reduce_dbias_rocm( - reinterpret_cast(ptr_to_reduce), - dbias, rows, cols, stream, workspace); - ); - ); - } else { - if (output && output->data.dptr) { - if constexpr (IS_DACT) { - NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output."); - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } else { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - - -} // namespace transformer_engine diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index af3a51373..ab574256c 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -230,13 +230,13 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, // Any element that is outside of bounds will be set to zero by the TMA transfer. CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); } -#endif //#ifndef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100() { int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); return deviceComputeCapability >= 100; } +#endif //#ifndef __HIP_PLATFORM_AMD__ std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size) { diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ae90ea4e5..03b90febb 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -354,7 +354,8 @@ using fp8e8m0 = __nv_fp8_e8m0; #endif // CUDA_VERSION >= 12080 #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; -<<<<<<< HEAD +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif //FP4_TYPE_SUPPORTED #else using bf16 = hip_bfloat16; @@ -362,11 +363,6 @@ using fp8e4m3 = te_hip_fp8_e4m3; using fp8e5m2 = te_hip_fp8_e5m2; #endif //__HIP_PLATFORM_AMD__ -======= -using fp4e2m1x2 = __nv_fp4x2_e2m1; -using fp4e2m1x4 = __nv_fp4x4_e2m1; -#endif ->>>>>>> 389a6b using e8m0_t = uint8_t; namespace detail { @@ -416,15 +412,14 @@ template <> struct TypeExtrema { #ifndef __HIP_PLATFORM_AMD__ static constexpr float max = 448.0f; -<<<<<<< HEAD + static constexpr float max_inverse = 1.0 / max; #elif defined(__HIP_DEVICE_COMPILE__) - static constexpr float maxNorm = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max_inverse = 1.0 / max; #else - static float maxNorm; + static float max; + static float max_inverse; #endif -======= - static constexpr float max_inverse = 1.0 / max; ->>>>>>> 389a6b }; template <> @@ -820,21 +815,15 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. -<<<<<<< HEAD -void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, - const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, - const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits); -#endif //#ifdef __HIP_PLATFORM_AMD__ -======= void create_2D_tensor_map( CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); ->>>>>>> 389a6b bool is_supported_by_CC_100(); +#endif //#ifdef __HIP_PLATFORM_AMD__ + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index bb5e22887..d39fccbce 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -276,9 +276,10 @@ void log_fused_attn_config( // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; // by default, fused attn is enabled @@ -311,6 +312,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, @@ -325,6 +327,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, @@ -339,12 +342,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool cuda_graph, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); @@ -384,9 +389,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); - + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, + cuda_graph); + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( b, h, max_seqlen, d, @@ -416,15 +422,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } // NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -468,8 +473,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ @@ -505,14 +510,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -556,8 +564,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( @@ -596,11 +605,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; @@ -649,9 +659,10 @@ void nvte_fused_attn_bwd_kvpacked( // fix the incompatible window size from upstream frameworks pytorch/jax std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -694,14 +705,16 @@ void nvte_fused_attn_bwd_kvpacked( // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); @@ -740,8 +753,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -780,14 +794,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -830,8 +845,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index a8a151b40..1c25fa031 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -44,6 +44,7 @@ bool is_aotriton_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -68,7 +69,10 @@ bool is_aotriton_backend_supported( if(!(is_no_mask_window_size || is_causal_mask_window_size)){ return false; } - + + if(softmax_type!=NVTE_VANILLA_SOFTMAX){ + return false; + } //aotriton fused attn does not support gqa mode now if(num_attn_heads!=num_gqa_groups){ return false; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index b016acc67..178bd8d8f 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -23,6 +23,7 @@ bool is_aotriton_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7ca6fc95f..8d639c47c 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -26,6 +26,7 @@ bool is_ck_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -80,6 +81,14 @@ bool is_ck_backend_supported( return false; } + // filter based on softmax type + if(softmax_type!=NVTE_VANILLA_SOFTMAX){ + if(nvte_log_ck_config){ + std::cout<<"AITER/CK fused attn does not support learnable sink yet"<>>>>>> 389a6b +#endif #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -333,7 +328,7 @@ namespace transformer_engine { void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void* workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void* alpha_ptr, const void* beta_ptr, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset = -1); #else // Use cublasLt @@ -928,12 +923,9 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_atomic_gemm); using namespace transformer_engine; -<<<<<<< HEAD #ifndef __HIP_PLATFORM_AMD__ // Check CUDA and cuBLAS versions -======= ->>>>>>> 389a6b #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); @@ -951,7 +943,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); -#endif //__HIP_PLATFORM_AMD__ +#endif const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); @@ -971,7 +963,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); -#endif +#endif //#ifndef __HIP_PLATFORM_AMD__ } void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, @@ -992,28 +984,24 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens } for (int i = 0; i < num_gemms; i++) { -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ - { - const Tensor *inputA = convertNVTETensorCheck(A[i]); - const Tensor *inputB = convertNVTETensorCheck(B[i]); - Tensor *outputD = convertNVTETensorCheck(D[i]); - const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); - Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); - Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 1.0f : 0.0f, - use_split_accumulator, math_sm_count, 0, 0, false, nullptr, - detail::get_compute_stream(i % num_streams), i % num_streams); - } + const Tensor *inputA = convertNVTETensorCheck(A[i]); + const Tensor *inputB = convertNVTETensorCheck(B[i]); + Tensor *outputD = convertNVTETensorCheck(D[i]); + const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); + Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); + Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); + + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], &alpha, &beta, + use_split_accumulator, math_sm_count, 0, 0, false, nullptr, + detail::get_compute_stream(i % num_streams), i % num_streams); #else - nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams)); -#endif -======= // Check whether GELU or dGELU epilogue is requested Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]); bool with_gelu_dgelu_epilogue = @@ -1038,7 +1026,7 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i], workspace[i % num_streams], &config, detail::get_compute_stream(i % num_streams)); ->>>>>>> 389a6b +#endif } // record events on compute streams diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index fef3966a5..97bd2e8a7 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1501,7 +1501,7 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void* alpha_ptr, const void* beta_ptr, bool use_split_accumulator, int math_sm_count, [[maybe_unused]] int m_split, [[maybe_unused]] int n_split, [[maybe_unused]] bool gemm_producer, [[maybe_unused]] const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset) @@ -1527,6 +1527,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int ldb = is_transb ? n : k; const int ldd = m; + float alpha = *reinterpret_cast(alpha_ptr); // Assumed to be on CPU + float beta = *reinterpret_cast(beta_ptr); // Assumed to be on CPU + ServiceStreamCtl ss_ctl; bool use_service_stream = (math_sm_count != 0) ? get_service_stream(math_sm_count, stream, ss_ctl) : false; diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index e4a86698c..158d8ea5d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -307,15 +307,12 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, /*! \brief Compute the backward of the dot product attention with packed QKV input. * -<<<<<<< HEAD + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * Support Matrix for ROCm AOTriton: \verbatim | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | aotriton| FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | \endverbatim -======= - * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. ->>>>>>> 389a6b * * Support Matrix: \verbatim @@ -462,15 +459,12 @@ void nvte_fused_attn_fwd_kvpacked( /*! \brief Compute the backward of the dot product attention with packed KV input. * -<<<<<<< HEAD + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * Support Matrix for ROCm AOTriton: \verbatim | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | | aotriton| FP16/BF16 | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | NO | NO/CAUSAL | Yes | arbitrary | arbitrary | \endverbatim -======= - * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. ->>>>>>> 389a6b * * Support Matrix: \verbatim diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index cfcd91646..a5278522c 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -29,7 +29,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "../cudnn_utils.h" #else -#include "../util/rocm_cast_kernels.cuh" +#include "../cast/mxfp8/quantize_mxfp8.cuh" #endif #include "../util/system.h" @@ -447,10 +447,10 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) const size_t scale_dim_X_rowwise = 32; const size_t scale_dim_Y_colwise = launch_params.training ? 32 : 1; - const size_t chunks_Y = DIVUP(rows, transformer_engine::MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, transformer_engine::MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, transformer_engine::MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, transformer_engine::MXFP8_CHUNKS_PER_BLOCK_X); + const size_t chunks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_X); const size_t scale_stride_rowwise = launch_params.z_tensor->scale_inv.shape[1]; const size_t scale_stride_colwise = launch_params.training ? launch_params.z_tensor->columnwise_scale_inv.shape[1] : 1; @@ -459,17 +459,18 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) e8m0_t *const scales_colwise_ptr = launch_params.training ? reinterpret_cast(launch_params.z_tensor->columnwise_scale_inv.dptr) : nullptr; - const dim3 block(transformer_engine::MXFP8_THREADS_PER_CHUNK); + const dim3 block(dispatch::mxfp8::quantize_kernel::MXFP8_THREADS_PER_CHUNK); const dim3 grid(blocks_X, blocks_Y); + using namespace dispatch::mxfp8::quantize_kernel; TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( scale_dim_Y_colwise, SCALE_DIM_Y, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( launch_params.z_tensor->dtype(), OType, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(compute_t))), IS_ALIGNED, - cast_mxfp8_2D_kernel<<>>( + quantize_mxfp8_kernel<<>>( reinterpret_cast(launch_params.params.z), nullptr, reinterpret_cast(launch_params.z_tensor->data.dptr), diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index cec7da248..6c21eab7b 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -65,12 +65,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; -<<<<<<< HEAD #ifndef __HIP_PLATFORM_AMD__ - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); -======= bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); ->>>>>>> 389a6b if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 54815851d..598e0ca08 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -51,12 +51,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; -<<<<<<< HEAD #ifndef __HIP_PLATFORM_AMD__ - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); -======= bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); ->>>>>>> 389a6b if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 7c6055629..c55f1f612 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -8,15 +8,10 @@ from __future__ import annotations import os from enum import Enum -<<<<<<< HEAD -from typing import Optional, Union, Callable, NamedTuple -from typing_extensions import Literal -======= from typing import Any, Literal, Optional, Union, Callable, NamedTuple from dataclasses import field ->>>>>>> 389a6b from pydantic.dataclasses import dataclass -from transformer_engine.common import is_fp8_fnuz +from transformer_engine.common import is_fp8_fnuz, te_rocm_build class _FormatHelper(NamedTuple): @@ -58,17 +53,12 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ -<<<<<<< HEAD - E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(fwd=E4M3.fwd, bwd=E5M2.bwd) -======= - - E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) - E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) - E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) + #TODO: bring E2M1 back after rocm support MXFP4 + if not te_rocm_build: + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) + E4M3 = _FormatHelper(max_fwd=_FormatMaxVals.E4M3.value, max_bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelper(max_fwd=_FormatMaxVals.E5M2.value, max_bwd=_FormatMaxVals.E5M2.value) HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) ->>>>>>> 389a6b @dataclass(frozen=True) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index af8eaaf67..69b44494b 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -28,7 +28,6 @@ using bf16__ = __hip_bfloat16; constexpr int amax_kernel_threads = 512; -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ template @@ -52,7 +51,6 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax, #endif -======= __launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) { if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { return; @@ -60,7 +58,6 @@ __launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const flo *amax_ptr = 0; } ->>>>>>> 389a6b template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, @@ -280,19 +277,13 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt float *amax_ptr = reinterpret_cast( (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( -<<<<<<< HEAD input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); - launch_amax_kernel(reinterpret_cast(input.data.dptr), - reinterpret_cast(output.amax.dptr), input.data.numel(), + launch_amax_kernel( + reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), #ifdef __HIP_PLATFORM_AMD__ - block_amax, block_capacity, + block_amax, block_capacity, #endif - noop_ptr, stream);); // NOLINT(*) -======= - input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel( - reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), noop_ptr, - stream);); // NOLINT(*) ->>>>>>> 389a6b + noop_ptr, stream);); // NOLINT(*) } } // anonymous namespace diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index d0a7cb85e..881b134e7 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -20,18 +20,14 @@ namespace transformer_engine { namespace { -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ #define __ldg(x) (*(x)) #endif #ifndef __HIP_PLATFORM_AMD__ -constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; -======= constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int NVFP4_BLOCK_SIZE = 16; ->>>>>>> 389a6b constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -376,138 +372,11 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; // 1D block scaling, row-wise or colum-wise -<<<<<<< HEAD - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int m = - input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; - const int k = - input->has_data() ? input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - - NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); - NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); - NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - if (output->has_data()) { - NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), - output->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), - output->columnwise_scale_inv.shape.end(), 1, - std::multiplies()), - "Input.columnwise_scale_inv size is not equal to " - "Output.columnwise_scale_inv size!"); - } - - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; - - dim3 block_size(TB_DIM, TB_DIM); - if (input->has_data()) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 2: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 1: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); - } - if (input->has_columnwise_data()) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 2: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 1: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); - } - // 2D block scaling -======= int m, k; if (input->has_data()) { m = input->scale_inv.shape[0]; k = input->scale_inv.shape[1]; ->>>>>>> 389a6b } else { if (nvfp4) { m = input->columnwise_scale_inv.shape[0]; diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh deleted file mode 100644 index b7c4cf837..000000000 --- a/transformer_engine/common/util/cast_kernels.cuh +++ /dev/null @@ -1,1546 +0,0 @@ -/************************************************************************* - * This file was modified for portability to AMDGPU - * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. - * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file cast_kernels.cuh - * \brief CUDA kernels to cast to/from FP8/MXFP8. - */ - -#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ - -#include -#ifndef __HIP_PLATFORM_AMD__ -#include -#endif //#ifndef __HIP_PLATFORM_AMD__ -#include -#include - -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_cast_kernels.cuh" -#endif - -namespace transformer_engine { - -#ifndef __HIP_PLATFORM_AMD__ -namespace mxfp8_kernel { - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 32; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - if constexpr (NO_ACTIVATIONS) { - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; - static_assert(BUFF_DIM_Y == 32); - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } - - float block_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - - // 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - - parity ^= 1; - - if constexpr (IS_DBIAS) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); - - constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - - const size_t shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const size_t shmem_elt_idx = swizzled_group_offset + e; - partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const size_t scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; - } - } - const size_t dbias_stride = cols; - const size_t dbias_offset_Y = blockIdx.y; - const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace mxfp8_kernel - -constexpr size_t FP8_CHUNK_DIM_Y = 128; -constexpr size_t FP8_CHUNK_DIM_X = 128; -constexpr size_t FP8_THREADS_PER_CHUNK = 128; -constexpr size_t FP8_BUFFERS_NUM = 2; -constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); - -constexpr size_t FP8_BUFFER_DIM_Y = 16; -constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 -constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 - -constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); - -template -__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output, - float *const dbias_workspace, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, - const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - const size_t dbias_offset_Y = blockIdx.y + tid_Y; - const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; - const bool col_out_of_bounds = my_column >= cols; - const size_t dbias_stride = cols; - - float partial_dbias = 0.f; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - - constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const size_t chunk_stage_offset_X = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); - } - } - -#pragma unroll - for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const size_t buff = iter % FP8_BUFFERS_NUM; - const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; - if (next_iter < FP8_ITERATIONS) { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); - } - } - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = row >= rows; - const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if constexpr (IS_DACT) { - if (!out_of_bounds) { - partial_dbias += elt; - } - } else { - // If no activation, elt is 0 so we can safely do this - partial_dbias += elt; - } - } - __builtin_assume(amax >= 0); - if (IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); - } - out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - - if constexpr (IS_DBIAS) { - const size_t dbias_offset_X = my_column; - const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t CHUNKS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; -constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; -constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; -constexpr size_t CHUNKS_PER_ITERATION = 32; -constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; -constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; -constexpr size_t SHMEM_BUFFERS = 2; -static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); - -template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; - const IType *input = input_ptr + block_offset; - OType *output = output_ptr + block_offset; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - - constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); - -#pragma unroll - for (int iter = 0; iter < ITERATIONS; ++iter) { - const size_t buff = iter % SHMEM_BUFFERS; - const size_t it_offset = iter * SHMEM_DIM; - - const size_t next_iter = iter + 1; - const size_t next_buff = next_iter % SHMEM_BUFFERS; - const size_t next_iter_offset = next_iter * SHMEM_DIM; - - if (next_iter < ITERATIONS) { - copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, - &(mbar[next_iter]), is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; - float elt = static_cast(in_sh[buff][shmem_offset]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(elt)); - out_sh[buff][shmem_offset] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - ptx::cp_async_bulk_tensor_1d_shared_to_global( - reinterpret_cast(output + it_offset), - reinterpret_cast(&out_sh[buff]), transaction_size_OUT); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read<1>(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -#endif // #ifndef __HIP_PLATFORM_AMD__ - -constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; -template -__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, - const size_t rows, const size_t cols) { - using ComputeVec = Vec; - using OutputVec = Vec; - - const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - if (thread_id * nvec >= cols) { - return; - } - - const float *const thread_in_base = dbias_partial + thread_id * nvec; - OType *const thread_out_base = dbias_output + thread_id * nvec; - - ComputeVec ldg_vec; - ComputeVec acc_vec; - acc_vec.clear(); - for (int i = 0; i < rows; ++i) { - ldg_vec.load_from(thread_in_base + i * cols); -#pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += ldg_vec.data.elt[e]; - } - } - - OutputVec stg_vec; -#pragma unroll - for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); - } - stg_vec.store_to(thread_out_base); -} - -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream) { - constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 - constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); - - NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); - const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); - - reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -#ifndef __HIP_PLATFORM_AMD__ -template -static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { - const size_t N = product(input.data.shape); - - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - NVTE_CHECK(isFullTile, "Only full tiles are supported."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - const size_t chunks = DIVUP(N, CHUNK_SIZE); - const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - const float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(THREADS_PER_BLOCK); - const dim3 grid(blocks); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - const IType *input_ptr = reinterpret_cast(input.data.dptr); - OType *output_ptr = reinterpret_cast(output->data.dptr); - - cast_fp8_1D_kernel<<>>( - input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) - ); // NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, - Tensor *workspace, cudaStream_t stream) { - checkCuDriverContext(stream); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); - const size_t blocks_Y = chunks_Y; - const size_t blocks_X = chunks_X; - - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(FP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - } - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); - - cast_fp8_2D_kernel - <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, - workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} -#endif // #ifndef __HIP_PLATFORM_AMD__ - -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, - const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { -#ifndef __HIP_PLATFORM_AMD__ - using namespace mxfp8_kernel; - checkCuDriverContext(stream); -#endif - - bool use_rowwise_scaling = output->has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - - if (use_rowwise_scaling) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - -#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; - constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; - constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; -#else - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - - constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; - - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -#endif - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - -#ifndef __HIP_PLATFORM_AMD__ - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } -#endif - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, -#ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ))); // NOLINT(*) -#else // #ifdef __HIP_PLATFORM_AMD__ - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } -#endif // #ifdef __HIP_PLATFORM_AMD__ - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace detail { - -using Empty = transformer_engine::Empty; - -__device__ inline float identity(float value, const Empty &) { return value; } - -struct DequantizeParam { - const float *scale_inv; -}; - -__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); -} - -} // namespace detail - -/* HIPCC has strict rules for __device__ functions usage on host. - It forbids not only calling but also other ODR-use assigning to variables - https://github.com/llvm/llvm-project/issues/105825 - Use templated struct wrapper to work around - */ -template -struct ActivationType -{ - static constexpr auto op = OP; -}; - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; -#else //#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; -#endif //#ifdef __HIP_PLATFORM_AMD__ - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(noop->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; -#else //#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; -#endif //#ifdef __HIP_PLATFORM_AMD__ - const size_t N = product(input->data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input->data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace { - -#ifndef __HIP_PLATFORM_AMD__ -static bool is_full_tile_1D_tensor(const Tensor *const t) { - const size_t N = product(t->data.shape); - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - return isFullTile; -} - -bool dimensions_supported_by_TMA(const Tensor *const t) { - const size_t cols = t->flat_last_dim(); - constexpr size_t TMA_bytes = 16; - const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); - return cols % alignment_requirement == 0; -} -#endif //#ifndef __HIP_PLATFORM_AMD__ - -} // namespace - -#ifndef __HIP_PLATFORM_AMD__ -// Supported by the Arch >= 10.0 -template -void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 - cast_fp8_1D(input, output, stream); - } else { - // Unaligned - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 (+dAct) - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } else { - // Unaligned - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - } else { - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -// Supported by the Arch < 10.0 -template -void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { - // zhongboz: should we just ignore IS_ACT here? - NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); - } - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} -#endif //#ifndef __HIP_PLATFORM_AMD__ - -template -void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, - Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - -#ifndef __HIP_PLATFORM_AMD__ - // NVIDIA - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } -#else - // AMD - fp8_quantize_rocm(input, act_input, noop, output, - dbias, workspace, stream); -#endif //#ifndef __HIP_PLATFORM_AMD__ -} - -namespace detail { - -template -void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor output, - NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - const Tensor *input_tensor; - const Tensor *activation_input_tensor; - if constexpr (IS_DBIAS || IS_DACT) { - // backward - input is incoming gradient - input_tensor = convertNVTETensorCheck(grad); - activation_input_tensor = convertNVTETensor(input); - } else { - // forward = input is activation input - input_tensor = convertNVTETensorCheck(input); - activation_input_tensor = nullptr; - } - auto output_tensor = convertNVTETensorCheck(output); - auto dbias_tensor = convertNVTETensor(dbias); - auto workspace_tensor = convertNVTETensor(workspace); - - const QuantizationConfig *quant_config_cpp = - reinterpret_cast(quant_config); - - // extract noop tensor from quant_config_cpp if it's not null - const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); - - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } -#ifndef __HIP_PLATFORM_AMD__ - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor.data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - columnwise_option = columnwise_compact - ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor.data, stream); - break; - } -#endif - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } -} - -} // namespace detail -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index c83322f93..09187069e 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -18,7 +18,9 @@ #endif // __HIP_PLATFORM_AMD__ #include +#ifndef __HIP_PLATFORM_AMD__ #include "nccl.h" +#endif //#ifndef __HIP_PLATFORM_AMD__ #ifdef NVTE_WITH_CUBLASMP #include @@ -123,6 +125,7 @@ #endif // NVTE_WITH_CUBLASMP +#ifndef __HIP_PLATFORM_AMD__ #define NVTE_CHECK_NCCL(expr) \ do { \ const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ @@ -130,5 +133,5 @@ NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ } \ } while (false) - +#endif //#ifndef __HIP_PLATFORM_AMD__ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 4f0b888c5..ef53c2670 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -21,11 +21,15 @@ #endif // CUDA_VERSION >= 12080 #include "common/utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "../util/vectorized_pointwise.h" +#endif //#ifndef __HIP_PLATFORM_AMD__ namespace transformer_engine { namespace ptx { +#ifndef __HIP_PLATFORM_AMD__ template struct ArchSpecific { constexpr static int id = N * 10; @@ -125,6 +129,8 @@ constexpr bool is_supported_arch() { #define ARCH_HAS_STOCHASTIC_ROUNDING \ NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) +#endif //#ifndef __HIP_PLATFORM_AMD__ + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) @@ -259,26 +265,8 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { } __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { -<<<<<<< HEAD -#ifdef __HIP_PLATFORM_AMD__ -#define __CUDA_ARCH_HAS_FEATURE__(x) 0 -#endif -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - : "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; -======= +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool is_blackwell = false; constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; if constexpr (is_blackwell) { uint16_t out; @@ -290,6 +278,7 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { : "f"(val)); return *reinterpret_cast(&out); } else { +#endif //#ifndef __HIP_PLATFORM_AMD__ // TODO: nan/inf needs to be set for any value // of nan/inf in input not just amax. if (isnan(val)) { @@ -309,8 +298,9 @@ __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { ++exponent; } return exponent; ->>>>>>> 389a6b +#ifndef __HIP_PLATFORM_AMD__ } +#endif //#ifndef __HIP_PLATFORM_AMD__ } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -328,6 +318,38 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_ #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +#ifdef __HIP_PLATFORM_AMD__ +template +__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); + + T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); + + shared_loader.load(l_x_vec, chunk_dim_x); + + if (g_row < total_rows) { + global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; + global_storer.store(g_col_primitive_start / N_VEC, total_cols); + } + } +} +#endif //#ifdef __HIP_PLATFORM_AMD__ + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( @@ -909,6 +931,47 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#ifdef __HIP_PLATFORM_AMD__ +// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. +template +__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + if (g_row < total_rows) { + const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); + + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); + + global_loader.load(g_col_primitive_start / N_VEC, total_cols); + shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; + shared_storer.store(l_x_vec, chunk_dim_x); + + } else { + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorer shared_storer(current_sh_row_base_ptr, chunk_dim_x); + +#pragma unroll + for (int i = 0; i < N_VEC; ++i) { + shared_storer.separate()[i] = static_cast(0); + } + shared_storer.store(l_x_vec, chunk_dim_x); + } + } +} +#else __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, const size_t chunk_Y, const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { @@ -929,6 +992,7 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +#endif //#ifdef __HIP_PLATFORM_AMD__ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, diff --git a/transformer_engine/common/util/rocm_vectorized_2d.cuh b/transformer_engine/common/util/rocm_vectorized_2d.cuh index 5877ddd87..eda0f437f 100644 --- a/transformer_engine/common/util/rocm_vectorized_2d.cuh +++ b/transformer_engine/common/util/rocm_vectorized_2d.cuh @@ -9,73 +9,5 @@ #include "../util/vectorized_pointwise.h" namespace transformer_engine { -// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. -template -__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - if (g_row < total_rows) { - const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); - - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); - - global_loader.load(g_col_primitive_start / N_VEC, total_cols); - shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; - shared_storer.store(l_x_vec, chunk_dim_x); - - } else { - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorer shared_storer(current_sh_row_base_ptr, chunk_dim_x); - -#pragma unroll - for (int i = 0; i < N_VEC; ++i) { - shared_storer.separate()[i] = static_cast(0); - } - shared_storer.store(l_x_vec, chunk_dim_x); - } - } -} - -template -__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); - - T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); - - shared_loader.load(l_x_vec, chunk_dim_x); - - if (g_row < total_rows) { - global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; - global_storer.store(g_col_primitive_start / N_VEC, total_cols); - } - } -} } // namespace transformer_engine diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 7e150ed6f..c56242d34 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -53,8 +53,6 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// -<<<<<<< HEAD -======= // Device-side error #define NVTE_DEVICE_ERROR(message) \ do { \ @@ -88,7 +86,6 @@ inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) //////////////////////////////////////////////////////////////////////////////////////////////////// ->>>>>>> 389a6b template struct Sum { inline __device__ Sum() {} From 0519b4ba1298f7b599c7a7e8330fb88dbdf4a9bb Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 10 Feb 2026 11:07:06 -0600 Subject: [PATCH 130/141] [ROCm] resolve the conflicts on jax side --- build_tools/jax.py | 2 +- .../jax/cpp_extensions/activation.py | 37 +--------- transformer_engine/jax/cpp_extensions/base.py | 31 ++------- transformer_engine/jax/cpp_extensions/misc.py | 6 -- .../jax/cpp_extensions/normalization.py | 8 +-- .../jax/cpp_extensions/quantization.py | 18 ----- transformer_engine/jax/csrc/extensions.h | 5 +- .../jax/csrc/extensions/attention.cpp | 67 ------------------- .../jax/csrc/extensions/cgemm_helper.cpp | 4 ++ .../jax/csrc/extensions/gemm.cpp | 18 +++-- transformer_engine/jax/csrc/extensions/misc.h | 2 + .../jax/csrc/extensions/pybind.cpp | 16 ++--- transformer_engine/jax/quantize/helper.py | 31 +-------- transformer_engine/jax/setup.py | 6 +- 14 files changed, 41 insertions(+), 210 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index 7886b8ba2..e67036f49 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -105,7 +105,7 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, - libraries=["nccl"], + libraries=["nccl"] if not rocm_build() else [], ) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 5897c2a74..df148265d 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,14 +11,8 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -======= from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING ->>>>>>> 389a6b from jax.sharding import PartitionSpec import numpy as np @@ -579,15 +573,6 @@ def shardy_sharding_rule( value_types, result_types, ): -<<<<<<< HEAD - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) - scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 -======= del ( out_dtype, act_enum, @@ -600,7 +585,6 @@ def shardy_sharding_rule( is_outer, mesh, result_types, ->>>>>>> 389a6b ) prefix = "ActLu" input_shape = value_types[0].shape @@ -1134,25 +1118,6 @@ def shardy_sharding_rule( value_types, result_types, ): -<<<<<<< HEAD - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" - scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 - ) - x_axes = scale_rules.input_spec - dz_axes = (*x_axes[:-2], x_axes[-1]) - out = x_axes - colwise_out = (prefix + "out_colwise",) - if is_2x: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) - else: - colwise_out = out -======= ->>>>>>> 389a6b del ( out_dtype, diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 6c53317af..176e0eadc 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -17,11 +17,7 @@ from jax._src import dispatch from jax import ffi -<<<<<<< HEAD from .misc import is_hip_extension -import jax -======= ->>>>>>> 389a6b import transformer_engine_jax @@ -193,24 +189,14 @@ def register_primitive(cls, outer_only=False): def name_of_wrapper_p(): return cls.name + "_wrapper" -<<<<<<< HEAD - inner_p = core.Primitive(cls.name) - dispatch.prim_requires_devices_during_lowering.add(inner_p) - inner_p.multiple_results = cls.multiple_results - inner_p.def_impl(partial(xla.apply_primitive, inner_p)) - inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") - cls.inner_primitive = inner_p -======= if not outer_only: inner_p = core.Primitive(cls.name) dispatch.prim_requires_devices_during_lowering.add(inner_p) inner_p.multiple_results = cls.multiple_results inner_p.def_impl(partial(xla.apply_primitive, inner_p)) inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="cuda") + mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") cls.inner_primitive = inner_p ->>>>>>> 389a6b outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) @@ -219,16 +205,11 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) - if version.parse(jax.__version__) >= version.parse("0.5.0"): - outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, - partition=cls.partition, - sharding_rule=cls.shardy_sharding_rule, - ) - else: - outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition - ) + outer_p_lower.def_partition( + infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition, + sharding_rule=cls.shardy_sharding_rule, + ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) ) diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 9443262c8..6c4be68ec 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -74,16 +74,10 @@ def jax_dtype_to_te_dtype(jax_dtype): jnp.bfloat16.dtype: TEDType.kBFloat16, jnp.int32.dtype: TEDType.kInt32, jnp.int64.dtype: TEDType.kInt64, -<<<<<<< HEAD get_jnp_float8_e4m3_type().dtype: TEDType.kFloat8E4M3, get_jnp_float8_e5m2_type().dtype: TEDType.kFloat8E5M2, - jnp.uint8.dtype: TEDType.kByte, -======= - jnp.float8_e4m3fn.dtype: TEDType.kFloat8E4M3, - jnp.float8_e5m2.dtype: TEDType.kFloat8E5M2, jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0, jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1, ->>>>>>> 389a6b } if jax_dtype not in converter: diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 37b5b077b..1bf6ec943 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -12,14 +12,8 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -======= from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING ->>>>>>> 389a6b from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec from .misc import is_hip_extension diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index fd7a101c8..bd2176170 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -12,14 +12,8 @@ import jax import jax.numpy as jnp -<<<<<<< HEAD -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule -======= from jax import dtypes, ffi from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING ->>>>>>> 389a6b from jax.sharding import PartitionSpec import transformer_engine_jax @@ -639,17 +633,6 @@ def shardy_sharding_rule( value_types, result_types, ): -<<<<<<< HEAD - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, scale_dtype, is_outer, mesh, result_types - - prefix = "BaseDBiasQuantizePrimitive_" - scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), - unique_var=prefix + "x", - flatten_axis=flatten_axis, -======= del ( out_dtype, scale_dtype, @@ -658,7 +641,6 @@ def shardy_sharding_rule( use_rht, mesh, result_types, ->>>>>>> 389a6b ) prefix = "DBiasQuantize" diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 396d7c089..845176080 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -17,7 +17,9 @@ #include #include #include +#ifndef USE_ROCM #include +#endif #include #include @@ -143,14 +145,11 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); -<<<<<<< HEAD #ifndef USE_ROCM -======= // Amax XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); ->>>>>>> 389a6b // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index a3c1a262b..1281eb272 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -522,72 +522,6 @@ static void FusedAttnBackwardImpl( auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { -<<<<<<< HEAD - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); - } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); -======= // QKV packed in q: [batch*seqlen, 3, heads, dim] NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen"); NVTE_CHECK(qk_head_dim == v_head_dim, @@ -614,7 +548,6 @@ static void FusedAttnBackwardImpl( dv_ptr = static_cast(static_cast(dk) + stride); // V has same shape as K since they're packed together v_shape = k_shape; ->>>>>>> 389a6b } auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype); diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp index 7082bfb03..4d44bb4a8 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -1,9 +1,12 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#ifndef USE_ROCM #include "cgemm_helper.h" #include "common/util/system.h" @@ -257,3 +260,4 @@ CommunicatorHandler::~CommunicatorHandler() { } // namespace jax } // namespace transformer_engine +#endif //#ifndef USE_ROCM diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 8f6383c0f..41b78f117 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -15,13 +15,17 @@ #include #include "../extensions.h" +#ifndef USE_ROCM #include "cgemm_helper.h" +#endif //#ifndef USE_ROCM #include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/string.h" #include "common/util/system.h" #include "cuda_runtime.h" +#ifndef USE_ROCM #include "nccl.h" +#endif //#ifndef USE_ROCM #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -98,6 +102,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( return std::make_tuple(std::move(input), input_shape); } +#ifndef USE_ROCM Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, @@ -162,6 +167,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, .Attr("grad") .Attr("use_split_accumulator") .Attr("collective_op")); +#endif //#ifndef USE_ROCM Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, @@ -273,6 +279,10 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/, out_.data() /*D*/, workspace_.data(), config, stream); } else { +#ifdef USE_ROCM + //TODO: better assert + std::cerr<<"ROCm TE jax does not integrate userbuffer for now"< buffer_shape{0, 0}; DType buffer_dtype = out_dtype; auto &comm_handler = CommunicatorHandler::get(); @@ -318,6 +328,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, workspace_, grad, false, use_split_accumulator, aux_out_, stream); } +#endif //#ifdef USE_ROCM } return ffi_with_cuda_error_check(); @@ -346,14 +357,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") -<<<<<<< HEAD - .Attr("use_split_accumulator"), - GemmFFI_CudaGraph_Traits); -======= .Attr("use_split_accumulator") .Attr("collective_op"), - FFI_CudaGraph_Traits); ->>>>>>> 389a6b + GemmFFI_CudaGraph_Traits); size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, int32_t *host_group_sizes) { diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index 21b50c1af..a0c5db5a8 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -128,6 +128,7 @@ enum class JAXX_Collective_Op : int64_t { REDUCE_SCATTER = 2, }; +#ifndef USE_ROCM static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { switch (op) { case JAXX_Collective_Op::ALL_GATHER: @@ -141,6 +142,7 @@ static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { break; } } +#endif } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 7998af062..bc47ef6bd 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -7,7 +7,9 @@ ************************************************************************/ #include "../extensions.h" +#ifndef USE_ROCM #include "cgemm_helper.h" +#endif //#ifndef USE_ROCM #include "common/util/cuda_runtime.h" namespace transformer_engine { @@ -78,12 +80,15 @@ pybind11::dict Registrations() { dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + // Amax + dict["te_rht_amax_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); #else // Normalization dict["te_norm_forward_ffi"] = EncapsulateFFI(NormForwardHandler); dict["te_norm_backward_ffi"] = EncapsulateFFI(NormBackwardHandler); -<<<<<<< HEAD // Attention dict["te_fused_attn_forward_ffi"] = EncapsulateFFI(FusedAttnForwardHandler); dict["te_fused_attn_backward_ffi"] = EncapsulateFFI(FusedAttnBackwardHandler); @@ -91,13 +96,6 @@ pybind11::dict Registrations() { dict["te_gemm_ffi"] = EncapsulateFFI(GemmHandler); dict["te_grouped_gemm_ffi"] = EncapsulateFFI(GroupedGemmHandler); #endif -======= - // Amax - dict["te_rht_amax_ffi"] = pybind11::dict( - pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), - pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); - ->>>>>>> 389a6b return dict; } @@ -121,8 +119,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); +#ifndef USE_ROCM m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); +#endif pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index ced397371..792173ed1 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -26,7 +26,6 @@ import jax.numpy as jnp from flax.core.frozen_dict import FrozenDict -<<<<<<< HEAD from ..util import is_hip_extension, get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type from transformer_engine_jax import DType @@ -35,10 +34,6 @@ get_cublasLt_version, get_cuda_version, ) -from transformer_engine.common import recipe -from transformer_engine.jax.sharding import global_shard_guard, MeshResource -======= -from transformer_engine_jax import DType, get_cublasLt_version, get_cuda_version from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -54,7 +49,6 @@ get_all_mesh_axes, with_sharding_constraint, ) ->>>>>>> 389a6b from .metadata import QuantizeMeta from .scaling_modes import ScalingMode @@ -102,16 +96,11 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ -<<<<<<< HEAD if is_hip_extension(): if gpu_arch in [94, 95]: return True, "" else: return False, "Device arch gfx94x or gfx95x required for FP8 execution." - if gpu_arch >= 90: # hopper and above - return True, "" -======= ->>>>>>> 389a6b if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if get_cublasLt_version() < 120103: @@ -130,13 +119,8 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: Returns: A tuple of (bool, str) indicating support and any error message """ -<<<<<<< HEAD if is_hip_extension(): return False, "FP8 block scaled gemm not yet supported for ROCm" - if gpu_arch >= 100: # blackwell and above - return True, "" -======= ->>>>>>> 389a6b if gpu_arch < 99: # pre-blackwell return False, "Device compute capability 9.9 or higher required for MXFP8 execution." if get_cublasLt_version() < 120800: @@ -259,23 +243,14 @@ def _format2dtypes(format_: Format): Returns: A tuple of (forward_dtype, backward_dtype) for the given format """ -<<<<<<< HEAD - if format_ == recipe.Format.E4M3: - return get_jnp_float8_e4m3_type(), get_jnp_float8_e4m3_type() - if format_ == recipe.Format.E5M2: - return get_jnp_float8_e5m2_type(), get_jnp_float8_e5m2_type() - if format_ == recipe.Format.HYBRID: - return get_jnp_float8_e4m3_type(), get_jnp_float8_e5m2_type() -======= if format_ == Format.E4M3: - return jnp.float8_e4m3fn, jnp.float8_e4m3fn + return get_jnp_float8_e4m3_type(), get_jnp_float8_e4m3_type() if format_ == Format.E5M2: - return jnp.float8_e5m2, jnp.float8_e5m2 + return get_jnp_float8_e5m2_type(), get_jnp_float8_e5m2_type() if format_ == Format.HYBRID: - return jnp.float8_e4m3fn, jnp.float8_e5m2 + return get_jnp_float8_e4m3_type(), get_jnp_float8_e5m2_type() if format_ == Format.E2M1: return jnp.float4_e2m1fn, jnp.float4_e2m1fn ->>>>>>> 389a6b return jnp.bfloat16, jnp.bfloat16 diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 589a96470..0b958d3ad 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -46,12 +46,8 @@ from build_tools.build_ext import get_build_ext -<<<<<<< HEAD -from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, - clear_hipify_tools_copy) -======= from build_tools.utils import copy_common_headers, min_python_version_str ->>>>>>> 389a6b +from build_tools.utils import rocm_build, copy_hipify_tools, clear_hipify_tools_copy from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements From 8f4b04db1d2d9debf0760eb27293b82b3a5cbb39 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 10 Feb 2026 11:08:54 -0600 Subject: [PATCH 131/141] [ROCm] resolve the conflicts on pytorch side --- .../dot_product_attention/backends.py | 4 - .../dot_product_attention/context_parallel.py | 53 +- .../pytorch/cpp_extensions/fused_attn.py | 21 +- transformer_engine/pytorch/csrc/common.cpp | 5 +- transformer_engine/pytorch/csrc/common.h | 4 +- .../pytorch/csrc/extensions/gemm.cpp | 16 +- transformer_engine/pytorch/csrc/util.cpp | 5 +- transformer_engine/pytorch/csrc/util.h | 5 +- transformer_engine/pytorch/fp8.py | 1093 +---------------- transformer_engine/pytorch/module/_common.py | 9 +- transformer_engine/pytorch/module/base.py | 16 +- .../pytorch/module/layernorm_linear.py | 10 +- .../pytorch/module/layernorm_mlp.py | 21 +- transformer_engine/pytorch/module/linear.py | 13 +- transformer_engine/pytorch/ops/fuser.py | 4 - transformer_engine/pytorch/quantization.py | 30 +- .../pytorch/quantized_tensor.py | 74 -- transformer_engine/pytorch/setup.py | 6 +- .../pytorch/tensor/_quantization_helpers.py | 1 + .../pytorch/tensor/float8_tensor.py | 26 +- .../pytorch/tensor/mxfp8_tensor.py | 42 +- .../pytorch/triton/cross_entropy.py | 211 ---- transformer_engine/pytorch/utils.py | 31 +- 23 files changed, 95 insertions(+), 1605 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index b455a0bd6..5437b73bc 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -15,12 +15,8 @@ from packaging.version import Version as PkgVersion import torch -<<<<<<< HEAD from torch.utils.cpp_extension import IS_HIP_EXTENSION - -======= import torch.nn.functional as F ->>>>>>> 389a6b import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( get_device_compute_capability, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c5e81516a..13b41345b 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1191,43 +1191,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, -<<<<<<< HEAD - ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) - - if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] - - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha - if is_input_fp8: - QKV_quantizer = q._quantizer - q, k, v = q._data, k._data, v._data - else: - q_f16, k_f16, v_f16 = q, k, v - if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16)._data - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - # partial result quantizer - for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() - O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) - else: - assert False, "FP8 is only supported with Fused Attention!" - else: - q_f16 = q - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] -======= ) = dpa_utils.get_attention_quantizers(fp8, quantizers) ->>>>>>> 389a6b q_f16 = None q_fp8, k_fp8, v_fp8 = (None, None, None) @@ -1293,7 +1257,7 @@ def forward( # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype q_f16 = q if use_fused_attention: - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] if return_max_logit: max_logit_per_step = [ torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size) @@ -2080,14 +2044,8 @@ def backward(ctx, dout, *_args): ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: -<<<<<<< HEAD - fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] -======= bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] ->>>>>>> 389a6b + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: @@ -3527,13 +3485,8 @@ def backward(ctx, dout, *_args): dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} -<<<<<<< HEAD - fused_attn_dqkv_dtype = TE_DType[dout_dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] -======= dqkv_te_dtype = TE_DType[dout.dtype] - fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen"] ->>>>>>> 389a6b + fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index b4a6e7ba8..852dcdb59 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -90,7 +90,12 @@ "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, } -<<<<<<< HEAD +SoftmaxType = { + "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, + "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, +} + if not IS_HIP_EXTENSION: FusedAttnBackend = { "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, @@ -104,20 +109,6 @@ "CK": NVTE_Fused_Attn_Backend.NVTE_CK, "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, } -======= -SoftmaxType = { - "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, - "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, - "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, -} - -FusedAttnBackend = { - "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, - "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, - "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, - "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, -} ->>>>>>> 389a6b BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index 2edb210ef..e1a78d49a 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -312,7 +312,6 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ inline bool nvte_use_atomic_amax() { @@ -336,7 +335,6 @@ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor) { #endif -======= void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { NVTE_SCOPED_GIL_RELEASE({ nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, @@ -353,5 +351,4 @@ at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_pe return philox_args; } ->>>>>>> 389a6b } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index a98818c88..74852b22d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -505,11 +505,10 @@ size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); -<<<<<<< HEAD #ifdef __HIP_PLATFORM_AMD__ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor); #endif -======= + std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); // unpack the PhiloxCudaState into CUDA tensor @@ -518,7 +517,6 @@ void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); // extract PhiloxCudaState from CUDA random number generator at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); ->>>>>>> 389a6b } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index cf99c2256..4a438d366 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -219,9 +219,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); -<<<<<<< HEAD -#ifndef USE_ROCM -======= // Construct GEMM config transformer_engine::MatmulConfigWrapper config; if (grad) { @@ -235,7 +232,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans config.set_use_split_accumulator(use_split_accumulator); config.set_sm_count(num_math_sms); ->>>>>>> 389a6b +#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; #endif @@ -246,7 +243,6 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); -#endif // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer // as it is not natively supported by cublasLt @@ -260,6 +256,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans transa = true; transb = false; } +#endif if (comm_overlap) { #ifndef USE_ROCM @@ -494,17 +491,9 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } -<<<<<<< HEAD #ifndef USE_ROCM - // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. - auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); -#endif -======= // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; ->>>>>>> 389a6b // Optionally swizzle the scaling factors swizzled_scale_inverses_list.emplace_back( @@ -544,6 +533,7 @@ std::optional> te_general_grouped_gemm( transb = false; } } +#endif std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, te_pre_gelu_out_vector; diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 9d25f67df..3948c6403 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -190,9 +190,6 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } -<<<<<<< HEAD -#endif //!USE_ROCM -======= at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, bool rowwise) { using namespace transformer_engine::pytorch; @@ -261,4 +258,4 @@ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapp input = std::move(output_cu); return swizzled_scale_inv; } ->>>>>>> 389a6b +#endif //!USE_ROCM diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 1305c9afc..9a46ae86d 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -31,9 +31,6 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); -<<<<<<< HEAD -#endif //!USE_ROCM -======= /*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. * * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid @@ -45,6 +42,6 @@ std::optional multi_tensor_swizzle_scaling_factors( */ at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, bool rowwise); ->>>>>>> 389a6b +#endif //!USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c364d5e45..b36302db2 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,11 +10,6 @@ # pylint: disable=wrong-import-position,unused-import -<<<<<<< HEAD -import torch -from torch.utils.cpp_extension import IS_HIP_EXTENSION -import transformer_engine_torch as tex -======= import warnings warnings.warn( @@ -29,7 +24,6 @@ # There are some users indirectly importing these classes # from fp8.py. This ensure backwards compatibility. # https://github.com/Lightning-AI/lightning-thunder/pull/2635. ->>>>>>> 389a6b from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -41,1090 +35,6 @@ CustomRecipe, ) -<<<<<<< HEAD -from .constants import dist_group_type -from .utils import get_device_compute_capability, get_torch_float8_e4m3_type, get_torch_float8_e5m2_type -from .jit import jit_fuser - -__all__ = ["fp8_autocast", "fp8_model_init"] - -def check_fp8_support() -> Tuple[bool, str]: - if IS_HIP_EXTENSION: - gpu_arch = get_device_compute_capability() - if gpu_arch in ((9, 4), (9, 5)): - return True, "" - else: - return False, "Device arch gfx94x or gfx95x required for FP8 execution." - else: - """Return if fp8 support is available""" - if get_device_compute_capability() >= (9, 0): # hopper and above - return True, "" - if get_device_compute_capability() < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if float(torch.version.cuda) < 12.1: - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." - return True, "" - -def check_mxfp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - if IS_HIP_EXTENSION: - if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") == "0": - return False, "MXFP8 support is not enabled." - gpu_arch = get_device_compute_capability() - if gpu_arch == (9, 5): - return True, "" - return False, "Gfx95x is required for MXFP8 execution." - if get_device_compute_capability() >= (12, 0): - return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." - if get_device_compute_capability() >= (10, 0): # blackwell and above - return True, "" - return False, "Device compute capability 10.0 or higher required for MXFP8 execution." - - -def check_fp8_block_scaling_support() -> Tuple[bool, str]: - """Return if fp8 block scaling support is available""" - if IS_HIP_EXTENSION: - return False, "FP8 block scaled gemm not yet supported for ROCm" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): - return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." - - -def check_recipe_support(recipe: Recipe) -> None: - """Check if the given recipe is supported.""" - recipe_supported = True - unsupported_reason = "" - if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): - recipe_supported, unsupported_reason = check_fp8_support() - elif isinstance(recipe, Float8BlockScaling): - recipe_supported, unsupported_reason = check_fp8_block_scaling_support() - elif isinstance(recipe, MXFP8BlockScaling): - recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason - - -def get_default_fp8_recipe() -> Recipe: - """FP8 recipe with default args.""" - if IS_HIP_EXTENSION: - if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") != "2": - return DelayedScaling() - gpu_arch = get_device_compute_capability() - if gpu_arch == (9, 5): - return MXFP8BlockScaling() - return DelayedScaling() - if check_mxfp8_support()[0]: - return MXFP8BlockScaling() - if get_device_compute_capability() >= (12, 0): - # This is a temporary restriction until MXFP8 is supported for all gemm layouts. - return Float8CurrentScaling() - return DelayedScaling() - - -def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return get_torch_float8_e4m3_type() - return get_torch_float8_e5m2_type() - - -def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get max representible FP8 value.""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return Format.E4M3.value.max_fwd - return Format.E5M2.value.max_fwd - - -class FP8GlobalStateManager: - """Class to keep track of and manipulate the global - FP8 state at different stages of execution. - """ - - FP8_ENABLED = False - FP8_CALIBRATION = False - FP8_RECIPE = None - FP8_DISTRIBUTED_GROUP = None - FP8_PARAMETERS = False - HIGH_PRECISION_INIT_VAL = False - IS_FIRST_FP8_MODULE = False - FP8_GRAPH_CAPTURING = False - SKIP_FP8_REDUCTION_FOR_FSDP2 = False - FP8_AUTOCAST_DEPTH = 0 - global_amax_buffer = {} - global_amax_history_buffer = {} - global_scale_buffer = {} - fp8_tensors_recompute_buffer = [] - fp8_available = None - reason_for_no_fp8 = "" - autocast_arguments = {} - autocast_to_fp8_params = {} - fp8_param_to_autocast = {} - skip_fp8_weight_update_tensor = None - mxfp8_available = None - reason_for_no_mxfp8 = "" - fp8_block_scaling_available = None - reason_for_no_fp8_block_scaling = None - - @classmethod - def reset(cls) -> None: - """Reset the global state""" - cls.FP8_ENABLED = False - cls.FP8_CALIBRATION = False - cls.FP8_RECIPE = None - cls.FP8_DISTRIBUTED_GROUP = None - cls.FP8_PARAMETERS = False - cls.HIGH_PRECISION_INIT_VAL = False - cls.IS_FIRST_FP8_MODULE = False - cls.FP8_GRAPH_CAPTURING = False - cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_amax_buffer = {} - cls.global_amax_history_buffer = {} - cls.global_scale_buffer = {} - cls.fp8_tensors_recompute_buffer = [] - cls.fp8_available = None - cls.reason_for_no_fp8 = "" - cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} - cls.skip_fp8_weight_update_tensor = None - cls.mxfp8_available = None - cls.reason_for_no_mxfp8 = "" - cls.fp8_block_scaling_available = None - cls.reason_for_no_fp8_block_scaling = "" - - @classmethod - def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: - """`skip_fp8_weight_update_tensor` inplace setter.""" - if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") - cls.skip_fp8_weight_update_tensor.fill_(skip) - - @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> None: - """`skip_fp8_weight_update_tensor` getter.""" - return cls.skip_fp8_weight_update_tensor - - @classmethod - def is_fp8_available(cls) -> Tuple[bool, str]: - """Return if fp8 support is available""" - if cls.fp8_available is None: - cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() - return cls.fp8_available, cls.reason_for_no_fp8 - - @classmethod - def is_mxfp8_available(cls) -> Tuple[bool, str]: - """Return if MXFP8/current scaling support is available.""" - if cls.mxfp8_available is None: - cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() - return cls.mxfp8_available, cls.reason_for_no_mxfp8 - - @classmethod - def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: - """Return if Float8 block scaling support is available.""" - if cls.fp8_block_scaling_available is None: - cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( - check_fp8_block_scaling_support() - ) - return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling - - @staticmethod - def get_meta_tensor_key(forward: bool = True) -> str: - """Returns scaling key in `fp8_meta`.""" - if forward: - return "scaling_fwd" - return "scaling_bwd" - - @staticmethod - def get_fwd_bwd_key(forward: bool = True) -> str: - """Convert bool `forward` to string.""" - return "forward" if forward else "backward" - - @classmethod - def get_buffer_info(cls) -> str: - """ - Returns a key for `fp8_meta` that stores the module's index - in the global buffers along with autocast information. - """ - return "buffer_index_and_autocast_key" - - @classmethod - def get_key_in_buffer( - cls, - forward: bool, - fp8_recipe: Recipe, - fp8_group: dist_group_type, - ) -> str: - """Returns a key into the global FP8 buffers.""" - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{autocast_key}" - - @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: - """Splits buffer key into relevant parts.""" - forward, autocast_key = key.split("_", 1) - forward = forward == "forward" - return forward, autocast_key - - @classmethod - def add_fp8_tensors_to_global_buffer( - cls, - fp8_meta: Dict[str, Any], - ) -> None: - """ - Delayed scaling only. - - The amax reduction process happens completely outside the FP8 modules. - To participate in the reduction, the only role played by a module is - to call this function in order to append it's FP8 tensor into a global - buffer. There are 5 global buffers maintained, one each for amax, amax - history, scale, scale-inverse, and non-weight-mask. Each buffer has - keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix - to indicate the type of FP8 tensor, since the forward and backward - reductions happen separately. - - Note: For CG capture, this method is called from the graphed - wrapper. For non CG case, it's called from within the module. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Every module must call this function exactly once since - # the amax tensors are static. Ensures that compatibility - # with non-graphed modules is maintained. - index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. - if index_in_buffer in fp8_meta: - return - - fp8_meta[index_in_buffer] = [] - for forward in (True, False): - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - if fp8_meta_tensor_key not in fp8_meta: - # Handles non-parameter FP8 modules, e.g. DPA. - continue - - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - - if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history - ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) - fp8_meta[index_in_buffer].append(key) - - @classmethod - def is_fp8_enabled(cls) -> bool: - """Is FP8 enabled""" - return cls.FP8_ENABLED - - @classmethod - def is_fp8_calibration(cls) -> bool: - """Is FP8 calibration""" - return cls.FP8_CALIBRATION - - @classmethod - def with_fp8_parameters(cls) -> bool: - """Should the parameters be stored as FP8""" - return cls.FP8_PARAMETERS - - @classmethod - def with_high_precision_init_val(cls) -> bool: - """Should the high precision initial values be stored with FP8 parameters""" - return cls.HIGH_PRECISION_INIT_VAL - - @classmethod - def fp8_graph_capturing(cls) -> bool: - """Is CUDA graph capture under way?""" - return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() - - @classmethod - def is_first_fp8_module(cls): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. - """ - tmp = cls.IS_FIRST_FP8_MODULE - cls.IS_FIRST_FP8_MODULE = False - return tmp - - @classmethod - def get_fp8_recipe(cls) -> Recipe: - """Return the fp8 recipe""" - if cls.FP8_RECIPE is not None: - return cls.FP8_RECIPE - return get_default_fp8_recipe() - - @classmethod - def get_fp8_group(cls) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return cls.FP8_DISTRIBUTED_GROUP - - @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: - """FP8 autocast state getter""" - return ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) - - @classmethod - def set_fp8_autocast_state( - cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] - ) -> None: - """FP8 autocast state setter""" - ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) = fp8_state - - @staticmethod - def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: - """Reduce tensor across given group.""" - if torch.distributed.is_initialized(): - torch.distributed.all_reduce( - tensor, - op=torch.distributed.ReduceOp.MAX, - group=group, - async_op=False, - ) - - @classmethod - def reduce_and_update_fp8_tensors( - cls, - forward: bool = True, - ) -> None: - """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" - # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in cls.global_amax_buffer.items(): - # Check for forward or backward reduction. - fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) - if fwd_update != forward: - continue - if len(amax_buffer) == 0: - continue - - # Retrieve autocast specific args and concat amaxes. - recipe, group = cls.autocast_arguments[autocast_key] - contiguous_amax = torch.cat(amax_buffer) - - # Reduction. - if ( - recipe.reduce_amax - and torch.distributed.is_initialized() - and torch.distributed.get_world_size(group=group) > 1 - ): - cls.reduce_tensor_across_group_op_max(contiguous_amax, group) - - # Amax and scale update. - unfused_update = ( - bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) - or callable(recipe.amax_compute_algo) - or callable(recipe.scaling_factor_compute_algo) - ) - - if not unfused_update: - tex.fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - recipe.amax_compute_algo, - get_fp8_te_dtype(recipe, forward), - recipe.margin, - ) - else: - split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - - for amax_history, scale in zip( - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - ): - _amax_and_scale_update( - amax_history, scale, get_fp8_max(recipe, forward), recipe - ) - - @classmethod - def get_unique_autocast_key( - cls, - recipe: Optional[Recipe] = None, - group: Optional[dist_group_type] = None, - ): - """ - For FP8, each autocast can be uniquely identified by the recipe and fp8 group. - Safely using `hash` as we never cross checkpoint boundaries. - """ - return f"{str(recipe)}:{hash(group)}" - - @classmethod - def fp8_autocast_enter( - cls, - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, - ) -> None: - """Set state and tracking variables for entry into FP8 region.""" - - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - - cls.FP8_ENABLED = enabled - cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = fp8_recipe - cls.FP8_DISTRIBUTED_GROUP = fp8_group - cls.FP8_GRAPH_CAPTURING = _graph - - if cls.FP8_AUTOCAST_DEPTH == 0: - cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_DEPTH += 1 - - if enabled: - fp8_available, reason_for_no_fp8 = cls.is_fp8_available() - assert fp8_available, reason_for_no_fp8 - if isinstance(fp8_recipe, MXFP8BlockScaling): - mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() - assert mxfp8_available, reason_for_no_mxfp8 - if isinstance(fp8_recipe, Float8BlockScaling): - fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() - assert fp8_block_available, reason_for_no_fp8_block - - @classmethod - def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: - """Set state and tracking variables for exit from FP8 region.""" - cls.FP8_AUTOCAST_DEPTH -= 1 - # Reduce only the non-FP8 weight modules here. - # FP8 weight modules are reduced at the end of the optimizer - # step after the weight amax is populated. - if not cls.SKIP_FP8_REDUCTION_FOR_FSDP2 and enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - # delayed scaling only function, for other recipes (current scaling with any granularity), - # this is noop for other recipes because cls.global_amax_buffer is empty list - cls.reduce_and_update_fp8_tensors(forward=True) - - @classmethod - def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Copy the scaling factors and amaxes for recompute forward phase - to ensure both forward steps are numerically same. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - ] - - if buffer_position_key in fp8_meta: - cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) - else: - if len(cls.fp8_tensors_recompute_buffer) == 0: - cls.fp8_tensors_recompute_buffer = [deque()] - else: - cls.fp8_tensors_recompute_buffer.append(deque()) - cls.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 - - @classmethod - def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Switch to the copied scaling factors and amaxes from phase - 1 forward for indentical numerical outputs. - """ - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone() - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone() - - # Retrieve stashed amaxes and scales from phase 1 pre forward. - buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) - fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) - fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - - -@contextmanager -def fp8_model_init( - enabled: bool = True, - recipe: Optional[Recipe] = None, - preserve_high_precision_init_val: bool = False, -) -> None: - """ - Context manager for FP8 initialization of parameters. - - Example usage: - - .. code-block:: python - - with fp8_model_init(enabled=True): - model = transformer_engine.pytorch.Linear(768, 768) - - # Preserving high precision initial value to initialize master weight - with fp8_model_init(enabled=True, preserve_high_precision_init_val=True): - model = transformer_engine.pytorch.Linear(768, 768) - master_weight = model.weight.get_high_precision_init_val() - model.weight.clear_high_precision_init_val() - - Parameters - ---------- - enabled: bool, default = `True` - when enabled, Transformer Engine modules created inside this `fp8_model_init` - region will hold only FP8 copies of its parameters, as opposed to the default - behavior where both higher precision and FP8 copies are present. Setting this - option to `True` may result in lower memory consumption and is especially - useful for scenarios like: - - * full model training using optimizer with master weights, where the high - precision copies of weights are already present in the optimizer. - * inference, where only the FP8 copies of the parameters are used. - * LoRA-like fine-tuning, where the main parameters of the model do not change. - recipe: transformer_engine.common.recipe.Recipe, default = `None` - Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. - preserve_high_precision_init_val: bool, default = `False` - when enabled, store the high precision tensor used to initialize FP8 parameters - in CPU memory, and add two function attributes named `get_high_precision_init_val()` - and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high - precision tensor. The purpose is that users can use this high-precision copy - to initialize master weights, avoiding the loss of precision that can occur when - using FP8 parameters directly. Note that after the master weights are initialized, - users should call `clear_high_precision_init_val()` to release this CPU memory. - - This functionality is *EXPERIMENTAL*. - """ - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE - _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL - FP8GlobalStateManager.FP8_PARAMETERS = enabled - FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val - try: - yield - finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters - FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val - - -@contextmanager -def fp8_autocast( - enabled: bool = True, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, -) -> None: - """ - Context manager for FP8 usage. - - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. - - Parameters - ---------- - enabled: bool, default = `True` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.Recipe, default = `None` - recipe used for FP8 training. - fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - if enabled: - check_recipe_support(fp8_recipe) - fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter( - enabled=enabled, - calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, - _graph=_graph, - ) - try: - yield - finally: - FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) - FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) - - -def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: - """Update amax history and set next amax to zero.""" - if amax_history.shape[0] > 1: - new_amax_history = torch.roll(amax_history, -1, 0) - amax_history.copy_(new_amax_history) - amax_history[0].fill_(0.0) - return amax_history - - -@torch.jit.script -def _default_get_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Default function to obtain amax from history.""" - if amax_compute_algo == "max": - amax = torch.max(amax_history, dim=0).values - else: # amax_compute_algo == "most_recent" - amax = amax_history[0].clone() - - amax_history = _update_amax_history(amax_history) - return amax_history, amax - - -@jit_fuser -def _default_sf_compute( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - margin: int, - _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter -) -> torch.Tensor: - """Default function to convert amax to scaling factor. - Computing the scaling factor requires consideration of the following scenarios: - 1. amax == 0: - No action is possible, set scale to the previous scale (or 1). - 2. 0 < amax < tiny_amax - The amax is too tiny that the scale becomes infinite in FP32. - Set scale = FP32_max - 3. tiny_amax <= amax < FP32_max: - Set scale = FP8_max (or scaled_max) / amax - 4. When amax == inf or amax == nan: - No action is possible, set scale to the previous scale (or 1). - """ - sf = (fp8_max / amax) / (2**margin) - sf = torch.where(amax > 0.0, sf, scale) - sf = torch.where(torch.isfinite(amax), sf, scale) - sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) - scale.copy_(sf) - return scale - - -def _compute_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: Union[Callable, str], -) -> Tuple[torch.Tensor, torch.Tensor]: - """Obtain the amax from the history.""" - - if callable(amax_compute_algo): - amax = amax_compute_algo(amax_history) - amax_history = _update_amax_history(amax_history) - return amax_history, amax - return _default_get_amax_and_update_history( - amax_history, - amax_compute_algo, - ) - - -def _compute_scaling_factor( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> torch.Tensor: - """Convert amax to scaling factor.""" - - if recipe.scaling_factor_compute_algo is None: - return _default_sf_compute( - amax, - scale, - fp8_max, - recipe.margin, - ) - return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) - - -def _amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> None: - """Updates FP8 meta tensors.""" - new_amax_history, amax = _compute_amax_and_update_history( - amax_history, - recipe.amax_compute_algo, - ) - new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) - scale.copy_(new_scale) - amax_history.copy_(new_amax_history) - - -def split_and_copy( - buffer: torch.Tensor, - outputs: List[torch.Tensor], - chunk_sizes: List[int], -) -> None: - """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" - splits = buffer.split(chunk_sizes) - torch._foreach_copy_(outputs, splits) - - -class RecipeState(abc.ABC): - """Configuration and state for a quantization recipe. - - This is a builder class for quantizers, which are in turn builder - classes for quantized tensors. - - This class may pack together the state for multiple quantizers, - which is helpful for applying fused kernels with less overhead. - - """ - - @staticmethod - def create( - recipe: Recipe, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> RecipeState: - """Factory method to create the state for a quantization recipe - - Parameters - ---------- - recipe: Recipe - Quantization recipe. - mode: {"forward", "backward"} - Training stage where quantization will be performed. - num_quantizers: int, default = 1 - Number of quantizers to create state for. - device: torch.device, default = default CUDA device - Device for quantized tensors. - - Returns - ------- - RecipeState: - Quantization recipe state. - - """ - - cls = None - if recipe.delayed(): - cls = DelayedScalingRecipeState - elif recipe.mxfp8(): - cls = MXFP8BlockScalingRecipeState - elif recipe.float8_current_scaling(): - cls = Float8CurrentScalingRecipeState - elif recipe.float8_block_scaling(): - cls = Float8BlockScalingRecipeState - else: - raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( - recipe, - mode=mode, - num_quantizers=num_quantizers, - device=device, - ) - - @abc.abstractmethod - def make_quantizers(self) -> list: - """Convert recipe state to quantizers. - - Quantizers are builder classes for quantized tensors. They are - typically used to convert a high-precision tensor (e.g. in - FP32 or BF16) into a quantized tensor (e.g. in FP8). - - """ - - -class DelayedScalingRecipeState(RecipeState): - """State for FP8 quantization with per-tensor delayed scaling. - - Delayed scaling recipe requires a scaling factor (applied when - casting to FP8) and a history of max-abs values ("amax") from - recent FP8 casts for updating the scaling factor. The scale update - is handled externally by `FP8GlobalStateManager`. - - """ - - recipe: DelayedScaling - mode: str - dtype: tex.DType - scale: torch.Tensor - amax_history: torch.Tensor - - def __init__( - self, - recipe: DelayedScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) - self.amax_history = torch.zeros( - recipe.amax_history_len, - num_quantizers, - dtype=torch.float32, - device=device, - ) - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_tensor import Float8Quantizer - - return [ - Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) - for i in range(self.num_quantizers) - ] - - -class Float8CurrentScalingRecipeState(RecipeState): - """Configuration for Per-tensor current scaling quantization. - - Per-tensor current quantization does not require state. - - """ - - recipe: Float8CurrentScaling - mode: str - dtype: tex.DType - device: torch.device - - def __init__( - self, - recipe: Float8CurrentScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - from .tensor.float8_tensor import Float8CurrentScalingQuantizer - - return [ - Float8CurrentScalingQuantizer(self.dtype, device=self.device) - for i in range(self.num_quantizers) - ] - - -class MXFP8BlockScalingRecipeState(RecipeState): - """Configuration for MXFP8 quantization. - - MXFP8 quantization does not require state. - - """ - - recipe: MXFP8BlockScaling - mode: str - dtype: tex.DType - - def __init__( - self, - recipe: MXFP8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.mxfp8_tensor import MXFP8Quantizer - - return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] - - -class Float8BlockScalingRecipeState(RecipeState): - """Configuration for Float8BlockScaling quantization. - - Float8BlockScaling quantization does not require state, - but different quantizers use different modes. - """ - - recipe: Float8BlockScaling - mode: str - qx_dtype: tex.DType - qw_dtype: tex.DType - qgrad_dtype: tex.DType - - def __init__( - self, - recipe: Float8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.qx_dtype = get_fp8_te_dtype(recipe, True) - self.qw_dtype = get_fp8_te_dtype(recipe, True) - self.qgrad_dtype = get_fp8_te_dtype(recipe, False) - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_blockwise_tensor import Float8BlockQuantizer - - if self.mode == "forward": - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward, and doesn't play nicely with QuantizeOp, - # which is not associated with a GEMM. - assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qw_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 3) - ] - ) - ) - - assert self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 2) - ] - ) - ) -======= # Importing each function instead of 'import *' allows us specify '__all__' in # quantize.py and also makes any newer additions to quantize.py invisible via @@ -1158,4 +68,3 @@ def make_quantizers(self) -> list: NVFP4BlockScalingRecipeState, CustomRecipeState, ) ->>>>>>> 389a6b diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 591fa60c2..4ba5da68d 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -1,19 +1,13 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Internal function used by multiple modules.""" -<<<<<<< HEAD -import os -from typing import Any, List, Optional, Tuple, Union, Callable -from dataclasses import dataclass -======= import dataclasses ->>>>>>> 389a6b import queue from typing import Any, Callable, List, Optional, Tuple, Union @@ -28,6 +22,7 @@ if IS_HIP_EXTENSION: from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton + import os def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6a3bda0c6..661cf3f2e 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -19,11 +19,8 @@ import torch import torch.nn.functional as F -<<<<<<< HEAD -from torch.utils.cpp_extension import IS_HIP_EXTENSION -======= from torch.distributed.tensor import DTensor ->>>>>>> 389a6b +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe @@ -49,21 +46,13 @@ from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -<<<<<<< HEAD if IS_HIP_EXTENSION: from ..tensor.fsdp2_allgather_tensor import FSDPAGTensor -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton -from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -======= from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage ->>>>>>> 389a6b from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -1328,11 +1317,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: raise RuntimeError("Weight quantizer has not been initialized") quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) quantizer.internal = False -<<<<<<< HEAD if IS_HIP_EXTENSION and not self.keep_fp8_weight_transpose_cache: quantizer.columnwise_usage=False -======= if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): device_mesh = dtensor_param.device_mesh amax_reduction_group = ( @@ -1342,7 +1329,6 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: ) quantizer.amax_reduction_group = amax_reduction_group quantizer.with_amax_reduction = True ->>>>>>> 389a6b # Quantize parameter param = quantizer(param) if IS_HIP_EXTENSION and self.use_fsdp2 and not self.primary_weights_in_fp8 and fp8_meta_index is not None: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 0d1f3f5b0..a906ea42d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -305,16 +305,11 @@ def forward( is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) # Configure quantizer -<<<<<<< HEAD - if weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) -======= # If weight is already quantized, no need to set quantizer states if is_weight_param_quantized: weight_quantizer = weight._quantizer elif weight_quantizer is not None: - weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) ->>>>>>> 389a6b + weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch @@ -447,14 +442,11 @@ def forward( ): ln_out.update_usage(rowwise_usage=False) -<<<<<<< HEAD # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: if isinstance(weightmat, QuantizedTensorBase): weightmat.update_usage(columnwise_usage=True) -======= ->>>>>>> 389a6b if cpu_offloading: mark_activation_offload(inputmat, mu, rsigma, ln_out) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4a0397502..3fefb650e 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -371,22 +371,17 @@ def forward( # which handles weight caching etc. # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch -<<<<<<< HEAD - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) -======= # No need to set the quantizer states if weights are already quantized if isinstance(fc1_weight, QuantizedTensorStorage): fc1_weight_quantizer = fc1_weight._quantizer elif fc1_weight_quantizer is not None: - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) if isinstance(fc2_weight, QuantizedTensorStorage): fc2_weight_quantizer = fc2_weight._quantizer elif fc2_weight_quantizer is not None: - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled) + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) ->>>>>>> 389a6b fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -579,8 +574,6 @@ def forward( # Cache state for backward pass if is_grad_enabled: -<<<<<<< HEAD - # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: if isinstance(fc1_weight_final, QuantizedTensorBase): @@ -588,8 +581,6 @@ def forward( if isinstance(fc2_weight_final, QuantizedTensorBase): fc2_weight_final.update_usage(columnwise_usage=True) -======= ->>>>>>> 389a6b if cpu_offloading: mark_activation_offload( inputmat, mu, rsigma, ln_out, fc1_out, fc1_out_without_bias, act_out @@ -906,11 +897,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( -<<<<<<< HEAD - fc2_weight, QuantizedTensorBase -======= ctx.fc2_weight, QuantizedTensorStorage ->>>>>>> 389a6b ): fc2_weight.update_usage(columnwise_usage=True) @@ -1168,11 +1155,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( -<<<<<<< HEAD - fc1_weight, QuantizedTensorBase -======= ctx.fc1_weight_quantizer, QuantizedTensorStorage ->>>>>>> 389a6b ): fc1_weight.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index e7d2cf8d8..ffa47d986 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -253,16 +253,10 @@ def forward( weightmat = weight if fp8 or debug: # Configure quantizer -<<<<<<< HEAD - if weight_quantizer is not None: - columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache - if not columnwise_usage and keep_fp8_weight_transpose_cache: -======= # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): - columnwise_usage = is_grad_enabled and inp.requires_grad - if not columnwise_usage: ->>>>>>> 389a6b + columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache + if not columnwise_usage and keep_fp8_weight_transpose_cache: columnwise_usage = ( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() @@ -414,14 +408,11 @@ def forward( if backward_needs_input: saved_inputmat = inputmat -<<<<<<< HEAD # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: if isinstance(weightmat, QuantizedTensorBase): weightmat.update_usage(columnwise_usage=True) -======= ->>>>>>> 389a6b if cpu_offloading and saved_inputmat is not None: mark_activation_offload(saved_inputmat) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index c1985e04c..7eb04fa27 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -29,16 +29,12 @@ fuse_forward_linear_bias_add, fuse_forward_linear_scale_add, ) -<<<<<<< HEAD if not IS_HIP_EXTENSION: from transformer_engine.pytorch.ops.fused import ( fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( -======= from transformer_engine.pytorch.quantized_tensor import ( ->>>>>>> 389a6b prepare_for_saving, restore_from_saved, ) diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 030370b9d..0b7eddb9f 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -31,6 +31,8 @@ from .utils import get_device_compute_capability from .jit import jit_fuser +from torch.utils.cpp_extension import IS_HIP_EXTENSION +from .utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type __all__ = [ "autocast", @@ -46,6 +48,12 @@ @functools.lru_cache(maxsize=None) def check_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" + if IS_HIP_EXTENSION: + gpu_arch = get_device_compute_capability() + if gpu_arch in ((9, 4), (9, 5)): + return True, "" + else: + return False, "Device arch gfx94x or gfx95x required for FP8 execution." if get_device_compute_capability() >= (9, 0): # hopper and above return True, "" if get_device_compute_capability() < (8, 9): # pre-ada @@ -60,6 +68,13 @@ def check_fp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_mxfp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" + if IS_HIP_EXTENSION: + if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") == "0": + return False, "MXFP8 support is not enabled." + gpu_arch = get_device_compute_capability() + if gpu_arch == (9, 5): + return True, "" + return False, "Gfx95x is required for MXFP8 execution." if get_device_compute_capability() >= (12, 0): return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." if get_device_compute_capability() >= (10, 0): # blackwell and above @@ -69,6 +84,8 @@ def check_mxfp8_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_nvfp4_support() -> Tuple[bool, str]: + if IS_HIP_EXTENSION: + return False, "ROCm TE currently not supporting NVFP4" """Return if nvfp4 support is available""" if get_device_compute_capability() >= (10, 0): # blackwell and above return True, "" @@ -78,6 +95,8 @@ def check_nvfp4_support() -> Tuple[bool, str]: @functools.lru_cache(maxsize=None) def check_fp8_block_scaling_support() -> Tuple[bool, str]: """Return if fp8 block scaling support is available""" + if IS_HIP_EXTENSION: + return False, "FP8 block scaled gemm not yet supported for ROCm" if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: return True, "" return ( @@ -101,6 +120,13 @@ def check_recipe_support(recipe: Recipe) -> None: def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" + if IS_HIP_EXTENSION: + if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") != "2": + return DelayedScaling() + gpu_arch = get_device_compute_capability() + if gpu_arch == (9, 5): + return MXFP8BlockScaling() + return DelayedScaling() if check_mxfp8_support()[0]: return MXFP8BlockScaling() if get_device_compute_capability() >= (12, 0): @@ -119,8 +145,8 @@ def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch. if fp8_recipe.fp8_format == Format.E4M3 or ( fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor ): - return torch.float8_e4m3fn - return torch.float8_e5m2 + return get_torch_float8_e4m3_type() + return get_torch_float8_e5m2_type() def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 35bacb32a..dd6a7ebc5 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -331,86 +331,12 @@ def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-a """Returns whether or not given tensor can be quantized""" return True -<<<<<<< HEAD:transformer_engine/pytorch/tensor/quantized_tensor.py -class _QuantizeFunc(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: torch.Tensor, - quantizer: Quantizer, - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton - use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) - quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize - return quantize_func(tensor, quantizer) - else: - return tex.quantize(tensor, quantizer) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. - - """ - - @staticmethod - def forward( - ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if constructor kwargs are not provided - if init_kwargs is None: - return tensor.detach() - - # Construct new tensor if constructor kwargs are provided - ctx.input_dtype = tensor.dtype - kwargs = tensor.get_metadata() - for key, val in init_kwargs.items(): - kwargs[key] = val - return type(tensor)(tensor.shape, tensor.dtype, **kwargs) - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - grad_input = grad_output - if grad_input.dtype == ctx.input_dtype: - grad_input = grad_input.detach() - else: - grad_input = grad_input.to(ctx.input_dtype) - return grad_input, None - - -def _stride_from_shape(shape: list[int]): - if len(shape) == 0: - return [] - rstride = [1] - for d in reversed(shape[1:]): - rstride.append(rstride[-1] * d) - return list(reversed(rstride)) -======= def get_usages(self) -> Dict[str, bool]: """Get the usage of the quantizer""" return { "rowwise": self.rowwise_usage, "columnwise": self.columnwise_usage, } ->>>>>>> 389a6b:transformer_engine/pytorch/quantized_tensor.py - class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index c59001e61..73f926c61 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -47,12 +47,8 @@ from build_tools.build_ext import get_build_ext -<<<<<<< HEAD -from build_tools.utils import ( - rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy ) -======= +from build_tools.utils import rocm_build, copy_hipify_tools, clear_hipify_tools_copy from build_tools.utils import copy_common_headers, min_python_version_str ->>>>>>> 389a6b from build_tools.te_version import te_version from build_tools.pytorch import ( setup_pytorch_extension, diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py index 2214edbff..55fc4785d 100644 --- a/transformer_engine/pytorch/tensor/_quantization_helpers.py +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -26,6 +26,7 @@ def forward( quantize_impl: Callable, ) -> QuantizedTensor: # pylint: disable=missing-function-docstring + # TODO: bring back triton based quantization return quantize_impl(tensor) @staticmethod diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 789d207b0..8f741b7f2 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -6,16 +6,9 @@ """Tensor class with FP8 data""" from __future__ import annotations -<<<<<<< HEAD -import os -from typing import Optional, Tuple, Iterable, Union -import warnings -from torch.utils.cpp_extension import IS_HIP_EXTENSION -======= from typing import Any, Optional, Tuple, Iterable, Union import warnings ->>>>>>> 389a6b import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex @@ -27,8 +20,11 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type + +from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton + import os aten = torch.ops.aten @@ -109,7 +105,13 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) def make_empty( self, @@ -304,7 +306,13 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) def make_empty( self, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 6008f0503..1848a60cf 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -8,22 +8,12 @@ from __future__ import annotations from collections.abc import Iterable import math -<<<<<<< HEAD -import os -from typing import Optional, Tuple, Union -from torch.utils.cpp_extension import IS_HIP_EXTENSION - -import torch -if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton -======= from typing import Optional, Tuple, Union, Any import warnings import torch from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState ->>>>>>> 389a6b import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType @@ -34,6 +24,11 @@ from ..quantized_tensor import QuantizedTensor, Quantizer from ._quantization_helpers import _IdentityFunc +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + import os + from ..triton_kernels.cast import te_quantize_triton + aten = torch.ops.aten @@ -89,7 +84,13 @@ def update_quantized( def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor implementation""" - return tex.quantize(tensor, self) + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" @@ -124,41 +125,26 @@ def make_empty( ) # Allocate FP8 data -<<<<<<< HEAD - data = torch.empty(shape, dtype=torch.uint8, device=device) - # ROCm TE does not implement fuse padding zeros so use zero tensor here - scale_inv = torch.zeros( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - ) -======= data = None scale_inv = None if self.rowwise_usage: data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) - scale_inv = torch.empty( + # ROCm TE does not implement fuse padding zeros so use zero tensor here + scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), dtype=torch.uint8, device=device, pin_memory=pin_memory, ) ->>>>>>> 389a6b # Allocate FP8 data transpose if needed columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: -<<<<<<< HEAD columnwise_data = torch.empty_like(data) # ROCm TE does not implement fuse padding zeros so use zero tensor here columnwise_scale_inv = torch.zeros( -======= - columnwise_data = torch.empty_like(data, pin_memory=pin_memory) - columnwise_scale_inv = torch.empty( ->>>>>>> 389a6b round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 815c2836c..498dd7cdd 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -14,218 +14,7 @@ import torch.distributed as dist import triton -<<<<<<< HEAD -import triton.language as tl from torch.utils.cpp_extension import IS_HIP_EXTENSION - - -@triton.jit -def online_softmax_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes the m/d components on this TP rank for the online softmax. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride (int): The stride of the m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) - else: - X_y = float("-inf") - else: - X_y = float("-inf") - - m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 - - # 3. [Online softmax] first pass: find max + sum - m = float("-inf") # m is the max value. use the notation from the paper - d = 0.0 # d is the sum. use the notation from the paper - - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( - tl.float32 - ) - block_max = tl.max(X_block) - m_new = tl.maximum(m, block_max) - d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) - m = m_new - - tl.store(m_d_X_y_ptr, m) - tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) - tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) - - -@triton.jit -def cross_entropy_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - loss_ptr, - loss_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - world_size, - ignore_idx, - n_cols, - n_non_ignore, - reduce_loss: tl.constexpr, - label_smoothing: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes both cross entropy loss and the gradient of the input. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - loss_ptr: Pointer to tensor to store the loss. - loss_stride (int): The stride of the loss tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride: The stride of m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - world_size (int): The size of world involved in this distributed loss calculation. - ignore_idx (int): Tokens to be ignored for loss and gradient calculation. - n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - if y == ignore_idx: - # set all X_ptr as 0 - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) - return - - loss_ptr += program_id * loss_stride - m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride - - # Need to reduce the m/d/X_y values from other TP ranks - m = tl.load(m_d_X_y_ptr) - d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) - ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) - - for i in range(1, world_size): - offset = i * 3 * n_non_ignore * m_d_X_y_stride - access_ptr = m_d_X_y_ptr + offset - m_new = tl.load(access_ptr) - d_new = tl.load(access_ptr + m_d_X_y_stride) - X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) - - d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) - m = tl.maximum(m, m_new) - ori_X_y = tl.maximum(ori_X_y, X_y_new) - - # Label smoothing is a general case of normal cross entropy - scaled_x_sum = 0.0 - eps = label_smoothing / (n_cols * world_size) - - # 4. [Online softmax] second pass: calculate the gradients - # dx_y = (softmax(x_y) - 1) / N - # dx_i = softmax(x_i) / N, i != y - # N is the number of non ignored elements in the batch - # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N - # = dx_i - (1 - label_smoothing) / N - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) - grad_dtype = X_block.dtype - X_block = X_block.to(tl.float32) - if label_smoothing > 0: - # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - # Scale gradients based on reduction mode - # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore - # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here - if reduce_loss: - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) - else: - X_block = tl.exp(X_block - m) / d - eps - tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) - - # We need tl.debug_barrier() to ensure the new result of X_ptr is written - tl.debug_barrier() - - # 5. Calculate the loss - - # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) - # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - loss = -(ori_X_y - m - tl.log(d)) - - # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps - # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) - # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) - # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) - # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 - if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) - loss = loss * (1 - label_smoothing) + smooth_loss - - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx) - # Apply the same conditional scaling logic for the target token - if reduce_loss: - X_y += -(1 - label_smoothing) / (n_non_ignore) - else: - X_y += -(1 - label_smoothing) - tl.store(X_ptr + y - vocab_start_idx, X_y) - - tl.store(loss_ptr, loss) -======= ->>>>>>> 389a6b - from transformer_engine.common.triton.cross_entropy import ( online_softmax_kernel, cross_entropy_kernel, diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index d91c07c45..86acb7932 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -468,7 +468,15 @@ def is_fp8_fnuz(): get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 -<<<<<<< HEAD +def assert_dim_for_all_gather( + tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer +) -> None: + """Assert that tensor dimensions are supported for all-gather""" + if with_all_gather: + assert quantizer.is_quantizable(tensor), ( + "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ + ) + def is_bf16_compatible() -> None: if IS_HIP_EXTENSION: # only MI200 and newer machines support bf16 @@ -481,24 +489,6 @@ def is_bf16_compatible() -> None: check on device compute capability to enforce sm_80 or higher. """ return torch.cuda.get_device_capability()[0] >= 8 -======= -def assert_dim_for_all_gather( - tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer -) -> None: - """Assert that tensor dimensions are supported for all-gather""" - if with_all_gather: - assert quantizer.is_quantizable(tensor), ( - "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ - ) - - -def is_bf16_compatible() -> bool: - """Replaces torch.cuda.is_bf16_compatible() with an explicit - check on device compute capability to enforce sm_80 or higher. - """ - return torch.cuda.get_device_capability()[0] >= 8 ->>>>>>> 389a6b - def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: """ @@ -535,14 +525,11 @@ def is_non_tn_fp8_gemm_supported() -> bool: @functools.lru_cache(maxsize=None) def get_cudnn_version() -> Tuple[int, int, int]: """Runtime cuDNN version (major, minor, patch)""" -<<<<<<< HEAD # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out if IS_HIP_EXTENSION: return (99, 0, 0) -======= import transformer_engine.pytorch.cpp_extensions as ext ->>>>>>> 389a6b encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude) From e60ff21fdd4420faf9573b32eb22f821ec32d585 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Tue, 10 Feb 2026 11:09:24 -0600 Subject: [PATCH 132/141] [ROCm] resolve the conflicts in setup --- pyproject.toml | 6 +----- setup.py | 15 +++++---------- 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index c32fc31a4..3814aabd0 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,15 +1,11 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. [build-system] -<<<<<<< HEAD -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax", "flax>=0.7.1"] -======= requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] ->>>>>>> 389a6b # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" diff --git a/setup.py b/setup.py index c6199c387..bec4943e1 100644 --- a/setup.py +++ b/setup.py @@ -245,17 +245,17 @@ def git_check_submodules() -> None: cmdclass = {} package_data = {} include_package_data = False -<<<<<<< HEAD - install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) -======= install_requires = [] ->>>>>>> 389a6b extras_require = { "core": [f"transformer_engine_cu12=={__version__}"], "core_cu12": [f"transformer_engine_cu12=={__version__}"], "core_cu13": [f"transformer_engine_cu13=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], + } if not rocm_build() else { + "core": [f"transformer_engine_{te_cuda_vers}=={__version__}"], + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], } else: install_requires, test_requires = setup_requirements() @@ -303,13 +303,8 @@ def git_check_submodules() -> None: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, -<<<<<<< HEAD - cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", -======= - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} if not rocm_build() else {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=f">={min_python_version_str()}", ->>>>>>> 389a6b classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), From 8bbb16214277009bed6a8327ed6312a5b44b3f59 Mon Sep 17 00:00:00 2001 From: Ye Wang Date: Wed, 11 Feb 2026 10:55:18 -0600 Subject: [PATCH 133/141] [ROCm] resolve the cpp gtest --- tests/cpp/operator/CMakeLists.txt | 38 ++---------- tests/cpp/operator/test_cast_mxfp8.cu | 45 ++++---------- .../operator/test_cast_mxfp8_gated_swiglu.cu | 60 +++++++------------ tests/cpp/test_common.cu | 42 ++++--------- tests/cpp/test_common.h | 13 +--- 5 files changed, 52 insertions(+), 146 deletions(-) diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index f0123ccf7..cd36993ce 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -4,7 +4,6 @@ # # See LICENSE for license information. -<<<<<<< HEAD list(APPEND test_cuda_sources test_cast.cu test_cast_current_scaling.cu @@ -24,16 +23,18 @@ list(APPEND test_cuda_sources test_act.cu test_normalization.cu test_normalization_mxfp8.cu + test_memset.cu test_multi_cast_transpose.cu test_multi_padding.cu test_multi_unpadding.cu test_causal_softmax.cu - test_swizzle.cu test_swap_first_dims.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu) + test_cast_nvfp4_transpose.cu + test_cast_float8blockwise.cu + test_swizzle.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu) @@ -70,37 +71,6 @@ else() add_executable(test_operator ${test_hip_sources}) endif() -======= -add_executable(test_operator - test_cast.cu - test_cast_current_scaling.cu - test_cast_dbias.cu - test_cast_dbias_dgelu.cu - test_cast_gated_swiglu.cu - test_cast_mxfp8_gated_swiglu.cu - test_qdq.cu - test_cast_mxfp8.cu - test_cast_nvfp4_transpose.cu - test_cast_float8blockwise.cu - test_dequantize_mxfp8.cu - test_transpose.cu - test_cast_transpose.cu - test_cast_transpose_current_scaling.cu - test_cast_transpose_dbias.cu - test_cast_transpose_dbias_dgelu.cu - test_cast_transpose_dgeglu.cu - test_act.cu - test_normalization.cu - test_normalization_mxfp8.cu - test_memset.cu - test_multi_cast_transpose.cu - test_multi_padding.cu - test_multi_unpadding.cu - test_causal_softmax.cu - test_swizzle.cu - test_swap_first_dims.cu - ../test_common.cu) ->>>>>>> 389a6b # Find required packages find_package(OpenMP REQUIRED) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index acba36464..a029e4f3f 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -315,29 +315,22 @@ void performTest_x1(const ProcessingMethod processing_method, #ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices; #else const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; #endif - std::vector mismatches_scales_indices; size_t mismatches_scales = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= - compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; @@ -510,50 +503,36 @@ void performTest_x2(const ProcessingMethod processing_method, #ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices_rowwise; + std::vector mismatches_scales_indices_colwise; #else const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; #endif - std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales_rowwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b - std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales_colwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index d6bcfef30..ba4144a7c 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -264,7 +264,9 @@ void performTest_x1(const size_t rows, rowwise, colwise); +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices; +#endif size_t mismatches_scales = 0; const size_t scale_diff_abs_tolerance = 0; const double abs_tolerable_mismatches_limit = 1.0; @@ -274,25 +276,11 @@ void performTest_x1(const size_t rows, ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { -<<<<<<< HEAD - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); - } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, @@ -300,12 +288,13 @@ void performTest_x1(const size_t rows, } else { compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); - ->>>>>>> 389a6b } #ifdef __HIP_PLATFORM_AMD__ @@ -411,44 +400,35 @@ void performTest_x2(const size_t rows, const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices_rowwise; +#endif size_t mismatches_scales_rowwise = 0; -<<<<<<< HEAD - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); - std::vector mismatches_scales_indices_colwise; - size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); -======= compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif mismatches_scales_rowwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ + std::vector mismatches_scales_indices_colwise; +#endif size_t mismatches_scales_colwise = 0; compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, +#endif mismatches_scales_colwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); ->>>>>>> 389a6b - #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, @@ -514,7 +494,7 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam bool>> {}; TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { - #ifdef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors. #else // #ifdef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 3286b1527..5427bc118 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -412,20 +412,13 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { -<<<<<<< HEAD - (void)cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); -======= const DType colwise_type = tensor_.dtype(); const size_t colwise_size = bytes(s, colwise_type); - cudaMemcpy(cpu_data_columnwise_.get(), + (void)cudaMemcpy(cpu_data_columnwise_.get(), tensor_.get_columnwise_data().data_ptr, colwise_size, cudaMemcpyDeviceToHost); ->>>>>>> 389a6b } if (isFp8Type(dtype()) || isFp4Type(dtype())) { if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { @@ -759,14 +752,6 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -<<<<<<< HEAD -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit) -======= template struct CastToType; @@ -783,10 +768,12 @@ struct CastToType { template void compare_scaling_factors(const std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit) ->>>>>>> 389a6b { using UpcastType = typename CastToType::type; auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); @@ -796,6 +783,9 @@ void compare_scaling_factors(const std::string &name, const T *test, const T *re const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); mismatches_num = 0; +#ifndef __HIP_PLATFORM_AMD__ + std::vector mismatch_indices; +#endif //#ifndef __HIP_PLATFORM_AMD__ for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { @@ -842,8 +832,6 @@ void compare_scaling_factors(const std::string &name, const T *test, const T *re } } -<<<<<<< HEAD - #ifdef __HIP_PLATFORM_AMD__ void adjust_ref_for_e8m0_scale_error(const std::string &name, const std::vector &mismatch_idx, @@ -887,11 +875,13 @@ void adjust_ref_for_e8m0_scale_error(const std::string &name, } } #endif // #ifdef __HIP_PLATFORM_AMD__ -======= // Instantiate templates template void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit); @@ -899,11 +889,13 @@ void compare_scaling_factors(const std::string &name, const uint8_t *te template void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit); ->>>>>>> 389a6b std::pair getTolerances(const DType type) { switch(type) { @@ -1069,13 +1061,6 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -<<<<<<< HEAD -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - (void)cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; -======= bool isFp4Type(DType type) { return type == DType::kFloat4E2M1; } @@ -1084,7 +1069,6 @@ int32_t getDeviceComputeCapability() { cudaDeviceProp deviceProp; cudaGetDeviceProperties(&deviceProp, 0); return 10 * deviceProp.major + deviceProp.minor; ->>>>>>> 389a6b } size_t first_dimension(const std::vector &shape) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 1d6d9107e..56154c9d9 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -488,24 +488,17 @@ void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -<<<<<<< HEAD -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); -======= template void compare_scaling_factors(const std::string &name, const T *test, const T *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef USE_ROCM + std::vector& mismatch_indices, +#endif //#ifdef USE_ROCM size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0, const double abs_tolerable_mismatches_limit = 0, const double rel_tolerable_mismatches_limit = 0); ->>>>>>> 389a6b - #ifdef USE_ROCM void adjust_ref_for_e8m0_scale_error(const std::string &name, const std::vector &mismatch_idx, From f573b40081340199654061e518abb9c195e96a81 Mon Sep 17 00:00:00 2001 From: alextmagro Date: Wed, 11 Feb 2026 21:02:08 +0000 Subject: [PATCH 134/141] [ROCm] resolve pytorch and jax tests Resolve wheels and examples --- build_tools/wheel_utils/build_wheels.sh | 103 +++++++----------- examples/pytorch/mnist/main.py | 20 +--- tests/jax/distributed_test_base.py | 6 +- tests/jax/test_custom_call_compute.py | 75 +------------ tests/jax/test_distributed_layernorm_mlp.py | 7 -- tests/jax/test_fused_attn.py | 4 - .../attention/run_attention_with_cp.py | 35 +----- tests/pytorch/attention/test_attention.py | 58 +--------- .../attention/test_attention_with_cp.py | 24 +--- tests/pytorch/attention/test_kv_cache.py | 10 +- tests/pytorch/distributed/run_fsdp2_model.py | 18 +-- tests/pytorch/distributed/test_fusible_ops.py | 5 +- tests/pytorch/test_cpu_offloading.py | 6 - .../test_float8_current_scaling_exact.py | 7 -- tests/pytorch/test_fusible_ops.py | 17 --- tests/pytorch/test_numerics.py | 48 +------- tests/pytorch/test_recipe.py | 7 +- tests/pytorch/utils.py | 16 +-- 18 files changed, 65 insertions(+), 401 deletions(-) diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 6db223691..0be852c8a 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -43,19 +43,10 @@ else fi if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install setuptools wheel -fi - -# Install deps -<<<<<<< HEAD -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install pybind11[global] ninja + ${PYBINDIR}pip install pybind11[global] ninja setuptools wheel else - ${PYBINDIR}pip install cmake pybind11[global] ninja + /opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel fi -======= -/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel ->>>>>>> 389a6b if $BUILD_METAPACKAGE ; then cd /TransformerEngine @@ -83,70 +74,52 @@ if $BUILD_COMMON ; then # Create the wheel. ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt -<<<<<<< HEAD - # Repack the wheel for cuda specific package, i.e. cu12. - ${PYBINDIR}wheel unpack dist/* - # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" - ${PYBINDIR}wheel pack ${WHL_BASE} -======= - # Repack the wheel for specific cuda version. - /opt/python/cp310-cp310/bin/wheel unpack dist/* - # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info" - /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} ->>>>>>> 389a6b + if [ "$ROCM_BUILD" = "1" ]; then + # Repack the wheel for cuda specific package, i.e. cu12. + ${PYBINDIR}wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" + ${PYBINDIR}wheel pack ${WHL_BASE} + else + # Repack the wheel for specific cuda version. + /opt/python/cp310-cp310/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info" + /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} + fi # Rename the wheel to make it python version agnostic. whl_name=$(basename dist/*) IFS='-' read -ra whl_parts <<< "$whl_name" -<<<<<<< HEAD whl_name_target="${whl_parts[0]}_${TE_CUDA_VERS}-${whl_parts[1]}-py3-none-${whl_parts[4]}" -======= - whl_name_target="${whl_parts[0]}_cu${CUDA_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}" ->>>>>>> 389a6b rm -rf $WHL_BASE dist mv *.whl /wheelhouse/"$whl_name_target" fi if $BUILD_PYTORCH ; then -<<<<<<< HEAD - cd /TransformerEngine/transformer_engine/pytorch - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 - else - PYBINDIR=/opt/python/cp38-cp38/bin/ - ${PYBINDIR}pip install torch - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ -fi - -if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install jax - else - PYBINDIR=/opt/python/cp310-cp310/bin/ - ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ -======= - cd /TransformerEngine/transformer_engine/pytorch - /opt/python/cp310-cp310/bin/pip install torch - /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ + cd /TransformerEngine/transformer_engine/pytorch + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install torch + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + cp dist/* /wheelhouse/ fi if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - /opt/python/cp310-cp310/bin/pip install "jax[cuda${CUDA_MAJOR}_local]" jaxlib - /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ ->>>>>>> 389a6b + cd /TransformerEngine/transformer_engine/jax + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install jax + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + cp dist/* /wheelhouse/ fi diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index 3516d5275..347d36e7c 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -54,17 +54,8 @@ def train(args, model, device, train_loader, optimizer, epoch, use_amp, use_fp8) for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() -<<<<<<< HEAD - if use_amp: - with autocast(device_type='cuda', dtype=torch.float16): - output = model(data) - else: - with te.fp8_autocast(enabled=use_fp8): - output = model(data) -======= with te.autocast(enabled=use_fp8): output = model(data) ->>>>>>> 389a6b loss = F.nll_loss(output, target) loss.backward() optimizer.step() @@ -99,17 +90,8 @@ def test(model, device, test_loader, use_amp, use_fp8): with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) -<<<<<<< HEAD - if use_amp: - with autocast(device_type='cuda', dtype=torch.float16): - output = model(data) - else: - with te.fp8_autocast(enabled=use_fp8): - output = model(data) -======= with te.autocast(enabled=use_fp8): output = model(data) ->>>>>>> 389a6b test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 244e8b5ee..e8d9cefd6 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,10 +10,6 @@ import pytest import jax -<<<<<<< HEAD -from jax._src.pjit import pjit -======= ->>>>>>> 389a6b from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index ad5ddf0d3..75b606a62 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -614,16 +614,12 @@ def test_norm_forward_with_tensor_scaling_fp8( ) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) -<<<<<<< HEAD - @pytest.mark.parametrize("out_dtype", FP8_COMPUTE_TYPE) -======= @pytest.mark.parametrize( "out_dtype", [ - jnp.float8_e4m3fn, + jnp_float8_e4m3_type if is_hip_extension() else jnp.float8_e4m3fn, ], ) ->>>>>>> 389a6b def test_norm_forward_with_block_scaling_fp8( self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype ): @@ -640,15 +636,9 @@ def test_norm_forward_with_block_scaling_fp8( ) -<<<<<<< HEAD QUANTIZE_OUTPUT_DTYPES = { "L0": [jnp_float8_e4m3_type], "L2": FP8_COMPUTE_TYPE, -======= -QUANTIZE_OUTPUT_FP8_DTYPES = { - "L0": [jnp.float8_e4m3fn], - "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2], ->>>>>>> 389a6b } QUANTIZE_OUTPUT_DTYPES = { test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn] @@ -692,11 +682,7 @@ def test_norm_forward_with_block_scaling_fp8( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) -<<<<<<< HEAD @pytest_parametrize_wrapper("q_dtype", FP8_COMPUTE_TYPE) -======= -@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn]) ->>>>>>> 389a6b @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper( @@ -1085,13 +1071,8 @@ def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, w @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) -<<<<<<< HEAD @pytest_parametrize_wrapper("q_dtype", [jnp_float8_e4m3_type]) -@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) -======= -@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn]) @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) ->>>>>>> 389a6b @pytest_parametrize_wrapper("flatten_axis", [-1]) @pytest_parametrize_wrapper("with_group_sizes", [True, False]) @pytest_parametrize_wrapper( @@ -1487,17 +1468,10 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( -<<<<<<< HEAD - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, -======= fp8_recipe=recipe, quantize_meta_set=QuantizeMetaSet( x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() ), ->>>>>>> 389a6b ) n_iterations = 3 if recipe.delayed() else 1 @@ -1511,17 +1485,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) -<<<<<<< HEAD - assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type) -======= assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype) assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype) ->>>>>>> 389a6b @pytest.fixture(name="random_inputs") @@ -1568,17 +1535,10 @@ def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm): gamma = jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) quantizer_set = QuantizerFactory.create_set( -<<<<<<< HEAD - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, -======= fp8_recipe=recipe, quantize_meta_set=QuantizeMetaSet( x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() ), ->>>>>>> 389a6b ) if norm_type == "layernorm": @@ -1624,21 +1584,12 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) -<<<<<<< HEAD - assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) - if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp_float8_e5m2_type) -======= assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype) assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype) if beta is not None: assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype) ->>>>>>> 389a6b @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) @@ -1676,17 +1627,10 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, -<<<<<<< HEAD - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, -======= fp8_recipe=recipe, quantize_meta_set=QuantizeMetaSet( x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() ), ->>>>>>> 389a6b ) if norm_type == "layernorm": @@ -1754,20 +1698,6 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) -<<<<<<< HEAD - assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp_float8_e5m2_type) - if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp_float8_e5m2_type) - - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp_float8_e5m2_type) - if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp_float8_e5m2_type) - - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) -======= fwd_dtype = quantizer_sets[0].x.q_dtype bwd_dtype = quantizer_sets[0].dgrad.q_dtype assert_allclose(prim_out, ref_out, dtype=fwd_dtype) @@ -1778,7 +1708,6 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): if use_bias: assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype) assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype) ->>>>>>> 389a6b # E5M2 * E5M2 is not supported diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index d58ebcef5..c67528f04 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -275,23 +275,16 @@ def _test_layernorm_mlp_grad( ) # +1 for multi_gpus multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) -<<<<<<< HEAD # TODO: skip cases with single fwd as nan/inf if is_hip_extension() and (jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd))): pytest.skip("skip tests with nan/inf single fwd.") - - fwd_test_type = dtype if fp8_recipe is None else jnp_float8_e4m3_type - bwd_test_type = dtype if fp8_recipe is None else jnp_float8_e5m2_type -======= - fwd_test_type = bwd_test_type = dtype if quantization_recipe is not None: quantize_config = get_quantize_config_with_recipe(quantization_recipe) fwd_test_type = quantize_config.FWD_DTYPE bwd_test_type = quantize_config.BWD_DTYPE ->>>>>>> 389a6b if fwd_test_type == jnp.float16 and use_bias: assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index a29725909..f33961455 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -24,12 +24,8 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike -<<<<<<< HEAD -from transformer_engine.jax import fp8_autocast from transformer_engine.jax.cpp_extensions.misc import is_hip_extension -======= from transformer_engine.jax import autocast ->>>>>>> 389a6b from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 8fc914053..80f21048f 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,23 +10,16 @@ from contextlib import nullcontext import torch import torch.distributed as dist -<<<<<<< HEAD +import warnings + from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.attention import DotProductAttention -======= ->>>>>>> 389a6b + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn -<<<<<<< HEAD -from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer -from transformer_engine.common.recipe import DelayedScaling -import warnings -======= from transformer_engine.pytorch import ( autocast, DotProductAttention, @@ -35,7 +28,6 @@ ) from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling from utils import ModelConfig, compare_and_assert ->>>>>>> 389a6b dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} @@ -336,16 +328,11 @@ def run_dpa_with_cp( core_attention_bias=bias, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, -<<<<<<< HEAD cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), -======= - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ->>>>>>> 389a6b ) if config.return_max_logit: out, max_logit = out @@ -438,16 +425,11 @@ def run_dpa_with_cp( core_attention_bias=bias_, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, -<<<<<<< HEAD cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1], cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), -======= - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, fp8_output=fp8_mha, ->>>>>>> 389a6b ) if config.return_max_logit: out_, max_logit_ = out_ @@ -491,17 +473,10 @@ def run_dpa_with_cp( for x in [dq_, dk_, dv_, out_] ] elif qkv_format == "thd": -<<<<<<< HEAD - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size -======= dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] - cu_seqlens_q_padded = cu_seqlens_q_padded // world_size ->>>>>>> 389a6b + cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 23463cc32..c0cf64803 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -12,12 +12,9 @@ import pytest import torch -<<<<<<< HEAD from torch.utils.cpp_extension import IS_HIP_EXTENSION -======= from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype ->>>>>>> 389a6b from transformer_engine.common import recipe from transformer_engine.pytorch import ( TransformerLayer, @@ -91,7 +88,6 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -<<<<<<< HEAD if IS_HIP_EXTENSION: from utils import EnvVarCleaner @pytest.fixture(autouse=True) @@ -101,13 +97,11 @@ def reset_attn_backend(): "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) yield -======= # Define F16 data types to test param_types = [torch.float16] if is_bf16_available(): param_types.append(torch.bfloat16) param_types_lean = [torch.bfloat16] ->>>>>>> 389a6b model_configs_base = { # test: ModelConfig(b, sq, hq, dqk) @@ -126,7 +120,6 @@ def reset_attn_backend(): } -<<<<<<< HEAD param_types = [torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) @@ -136,10 +129,8 @@ def reset_attn_backend(): # backend is capable of supporting it. @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_gqa_mla_thd(): - """ - Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration - post-processing for BWD FA with native padding support. - """ + """Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration + post-processing for BWD FA with native padding support.""" # b, sq, h, dqk config = ModelConfig(8, 128, 16, 128, num_gqa_groups= 4, head_dim_v=64, attn_mask_type="padding") qkv_layout = "thd_thd_thd" @@ -156,11 +147,10 @@ def test_gqa_mla_thd(): test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False) + @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_dot_product_mem_calc(): - """ - Non-regression test for memory workspace calculation integer overflow issue. - """ + """Non-regression test for memory workspace calculation integer overflow issue.""" ckpt_attn = False pad_between_seqs = False if not is_bf16_compatible(): @@ -197,8 +187,6 @@ def test_dot_product_mem_calc(): del os.environ["NVTE_FUSED_ATTN_AOTRITON"] -======= ->>>>>>> 389a6b @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model_configs", [model_configs_base]) @@ -306,13 +294,9 @@ def test_dot_product_attention( ) if len(fused_attn_backends) == 2: os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" -<<<<<<< HEAD os.environ["NVTE_FUSED_ATTN_CK"] = "0" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1" - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( -======= fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention( ->>>>>>> 389a6b dtype, config, "FusedAttention", @@ -324,15 +308,11 @@ def test_dot_product_attention( share_cu_seqlens_ref, # Not used by AOT ) os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" -<<<<<<< HEAD os.environ["NVTE_FUSED_ATTN_CK"] = "1" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0" os.environ["NVTE_CK_USES_FWD_V3"] = "1" os.environ["NVTE_CK_USES_BWD_V3"] = "1" - fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( -======= fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention( ->>>>>>> 389a6b dtype, config, "FusedAttention", @@ -1926,37 +1906,7 @@ def get_model(dtype, config): qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] -<<<<<<< HEAD -def _rmse(a, b): - return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) - - -def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): - logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) - logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) - try: - if a.dtype != b.dtype: - a = a.to(b.dtype) - torch.testing.assert_close(a, b, atol=atol, rtol=rtol) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert rmse < rmse_tol * rmse_range, ( - name_a - + " vs " - + name_b - + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - ) - - @pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm") -======= ->>>>>>> 389a6b @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index a9b0afe89..9ac96dcef 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -12,12 +12,10 @@ import pytest import torch -<<<<<<< HEAD + from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.utils import ( -======= + from transformer_engine.pytorch import ( ->>>>>>> 389a6b get_device_compute_capability, get_cudnn_version, ) @@ -87,13 +85,8 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") -<<<<<<< HEAD @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) -======= -@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", dtypes) ->>>>>>> 389a6b @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) @@ -123,11 +116,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ) if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") -<<<<<<< HEAD if IS_HIP_EXTENSION: if config.head_dim_qk != config.head_dim_v and not FlashAttentionUtils.v3_is_installed: pytest.skip("MLA FlashAttention requires v3+!") -======= dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} available_backends, *_ = get_available_attention_backends( config, @@ -137,7 +128,6 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): flash_attn_supported, *_ = available_backends if not flash_attn_supported: pytest.skip("No attention backend available.") ->>>>>>> 389a6b subprocess.run( get_bash_arguments( @@ -207,13 +197,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") -<<<<<<< HEAD @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) -======= -@pytest.mark.skipif(get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") @pytest.mark.parametrize("dtype", dtypes) ->>>>>>> 389a6b @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) @@ -235,15 +220,12 @@ def test_cp_with_fused_attention( pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!") if (not IS_HIP_EXTENSION) and dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") -<<<<<<< HEAD if IS_HIP_EXTENSION and dtype == "fp8": - pytest.skip("FP8 attention has not been supported on ROCm yet!") -======= + pytest.skip("FP8 attention is not supported on ROCm yet!") if dtype == "fp8" and not fp8_dpa and fp8_mha: pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") if dtype != "fp8" and fp8_bwd: pytest.skip("Only fp8 works with fp8_bwd=True!") ->>>>>>> 389a6b config = model_configs_fused_attn[model] config.context_parallel = True diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index bab34ef28..eb86c0776 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -15,19 +15,13 @@ import pytest import torch -from torch.distributions import Exponential -<<<<<<< HEAD from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch import make_graphed_callables -from transformer_engine.common import recipe -from transformer_engine.pytorch import fp8_autocast, fp8_model_init -from transformer_engine.pytorch.transformer import ( -======= + +from torch.distributions import Exponential from transformer_engine.pytorch import ( make_graphed_callables, autocast, quantized_model_init, ->>>>>>> 389a6b TransformerLayer, DotProductAttention, InferenceParams, diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 2ec97518f..b9fe33593 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -17,6 +17,7 @@ Float8CurrentScaling, MXFP8BlockScaling, ) +from transformer_engine.pytorch import torch_version import torch import torch.distributed as dist @@ -28,23 +29,8 @@ from torch.distributed.device_mesh import init_device_mesh from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext -<<<<<<< HEAD -from transformer_engine.pytorch import torch_version - -class SimpleNet(nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super(SimpleNet, self).__init__() - self.fc1 = te.Linear(input_size, hidden_size) - self.fc2 = te.Linear(hidden_size, output_size) - - def forward(self, x): - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x -======= LOCAL_RANK = None ->>>>>>> 389a6b def dist_print(msg): diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 3ce4ca7cd..85ae2d85b 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -31,10 +31,7 @@ is_bf16_available, ) import transformer_engine.pytorch.ops as te_ops -<<<<<<< HEAD -from transformer_engine.pytorch.utils import is_bf16_compatible, is_fp8_fnuz -======= ->>>>>>> 389a6b +from transformer_engine.pytorch.utils import is_fp8_fnuz import transformer_engine_torch as tex # Import utility functions diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index f3f649fb8..c5b4b48b6 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -46,15 +46,9 @@ NUM_LAYERS = model_config["small"].num_layers EPSILON = model_config["small"].eps -<<<<<<< HEAD -# Flash attention saves some internal tensor for the backward pass -# that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" -======= # Disable garbage collection to tests if there are reference cycles. # We do not want them, because they can result in CUDA out of memory errors. import gc ->>>>>>> 389a6b gc.disable() diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index c5c26e873..d21c2e366 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -9,15 +9,8 @@ import transformer_engine.pytorch as te -<<<<<<< HEAD -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.common.recipe import Float8CurrentScaling, Format -from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype -======= from transformer_engine.common.recipe import Float8CurrentScaling from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype ->>>>>>> 389a6b # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 37b5e9ee9..a67fd4f45 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -35,12 +35,7 @@ NVFP4Quantizer, is_bf16_available, ) -<<<<<<< HEAD -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer -from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import get_device_compute_capability -======= ->>>>>>> 389a6b import transformer_engine_torch as tex from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -1377,20 +1372,8 @@ def test_rmsnorm( # Expected numerical error tols = dtype_tols(dtype) - # Explicit checks for quantization if quantized_compute: -<<<<<<< HEAD - tols = dtype_tols(y_test._quantizer.dtype) - expected_tensor_cls = { - Float8Quantizer:Float8Tensor, - Float8CurrentScalingQuantizer:Float8Tensor, - MXFP8Quantizer:MXFP8Tensor - }[type(y_test._quantizer)] - assert isinstance(y_test, expected_tensor_cls) - y_test = y_test.dequantize(dtype=torch.float32) -======= tols = quantization_tols(quantization) ->>>>>>> 389a6b # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 0aa932dfc..50578fc1a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -38,15 +38,6 @@ LayerNorm, Fp8Padding, Fp8Unpadding, -<<<<<<< HEAD -) -from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils -from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import ( -======= ->>>>>>> 389a6b Float8Quantizer, Float8CurrentScalingQuantizer, MXFP8Quantizer, @@ -57,19 +48,15 @@ is_bf16_available, is_nvfp4_available, ) +from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils from transformer_engine.pytorch import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.common import recipe import transformer_engine_torch as tex -<<<<<<< HEAD -from utils import ModelConfig, reset_rng_states, get_available_attention_backends +from utils import ModelConfig, reset_rng_states if IS_HIP_EXTENSION: from utils import EnvVarCleaner -======= -from utils import ModelConfig, reset_rng_states ->>>>>>> 389a6b - # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) @@ -202,28 +189,6 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> use_cutlass_grouped_gemm.append(True) -<<<<<<< HEAD -def is_fused_attn_available( - config: ModelConfig, - dtype: torch.dtype, - qkv_layout="bshd_bshd_bshd", - is_training=True, - deterministic=False, -): - _, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - is_training=is_training, - deterministic=deterministic, - ) - if IS_HIP_EXTENSION: - return fused_attn_backends != [] - return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends - - -======= ->>>>>>> 389a6b def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -806,7 +771,6 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") -<<<<<<< HEAD if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): if (dtype == torch.bfloat16 and not fp8 @@ -814,13 +778,11 @@ def test_gpt_full_activation_recompute( and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") -======= if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" ) ->>>>>>> 389a6b config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow @@ -972,12 +934,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] -<<<<<<< HEAD - if not is_fused_attn_available(config, dtype, deterministic=True): - pytest.skip("No attention backend available.") - -======= ->>>>>>> 389a6b outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index d8870d3da..6850be9b4 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -30,12 +30,7 @@ _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops -<<<<<<< HEAD from transformer_engine.pytorch.utils import is_fp8_fnuz -from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear -from transformer_engine.pytorch.distributed import fp8_autocast -======= ->>>>>>> 389a6b from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 28236a18b..05555626b 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -15,12 +15,6 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine -<<<<<<< HEAD -import transformer_engine.common.recipe -import transformer_engine.pytorch as te -from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type -======= ->>>>>>> 389a6b import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import InferenceParams @@ -32,6 +26,7 @@ check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type torch_float8_e4m3_type = get_torch_float8_e4m3_type() torch_float8_e5m2_type = get_torch_float8_e5m2_type() @@ -105,15 +100,10 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) - if dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz: + if dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 -<<<<<<< HEAD - if dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 -======= - if dtype == torch.float8_e5m2: + if dtype in (torch.float8_e5m2, torch.float8_e5m2fnuz): return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 ->>>>>>> 389a6b raise ValueError(f"Unsupported dtype ({dtype})") From eaaae946f976c5a3ef634bea4193bbde2f68743a Mon Sep 17 00:00:00 2001 From: alextmagro Date: Thu, 19 Feb 2026 18:49:23 +0000 Subject: [PATCH 135/141] pytest, example, wheels conflict resolution --- ci/pytorch.sh | 3 ++- tests/pytorch/test_cpu_offloading.py | 11 +++++++++++ tests/pytorch/test_cpu_offloading_v1.py | 4 +++- .../test_float8_current_scaling_exact.py | 6 ++++-- .../test_layernorm_saved_tensors_logic.py | 8 ++++---- tests/pytorch/test_numerics.py | 6 ++++-- .../transformer_engine/hadamard_transform.h | 6 ++++++ transformer_engine/common/recipe/__init__.py | 12 ++++++------ transformer_engine/pytorch/csrc/common.h | 2 ++ .../pytorch/csrc/extensions/activation.cpp | 12 ++++++++++++ .../pytorch/csrc/extensions/bias.cpp | 6 ++++++ .../pytorch/csrc/extensions/cast.cpp | 4 ++++ .../pytorch/csrc/extensions/normalization.cpp | 16 ++++++++++++++++ transformer_engine/pytorch/csrc/pybind.h | 4 ++++ transformer_engine/pytorch/csrc/quantizer.cpp | 2 ++ transformer_engine/pytorch/module/base.py | 2 +- .../pytorch/module/layernorm_linear.py | 2 +- .../pytorch/module/layernorm_mlp.py | 8 ++++---- transformer_engine/pytorch/module/linear.py | 2 +- transformer_engine/pytorch/quantization.py | 1 + .../pytorch/tensor/mxfp8_tensor.py | 4 +++- .../pytorch/triton_kernels/cast.py | 18 +++++++++--------- .../pytorch/triton_kernels/layernorm.py | 2 +- .../pytorch/triton_kernels/rmsnorm.py | 2 +- 24 files changed, 108 insertions(+), 35 deletions(-) diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..5558beca3 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -51,7 +51,8 @@ run_test_config(){ run_default_fa 1 test_deferred_init.py run_default_fa 1 test_float8tensor.py run_default_fa 1 test_float8_current_scaling_exact.py - test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 run 1 test_cpu_offloading.py + run 1 test_cpu_offloading.py + test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py run_default_fa 1 test_fused_rope.py run_default_fa 1 test_fused_router.py run_default_fa 1 test_fusible_ops.py diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index c5b4b48b6..4e4c71e14 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -21,6 +23,8 @@ from utils import ModelConfig import transformer_engine_torch as tex +from torch.utils.cpp_extension import IS_HIP_EXTENSION + # Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() @@ -626,6 +630,13 @@ def test_numerics( "Fused attention + cuda graphs is temporarily broken, not because of cpu offloading" ) + if (IS_HIP_EXTENSION + and backend == "FusedAttention" + and not use_cuda_graphs + and layer_type in ("multihead_attention", "transformer_layer") + ): + pytest.skip("No dot product attention backend is available for the provided inputs") + os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "0" os.environ["NVTE_UNFUSED_ATTN"] = "0" diff --git a/tests/pytorch/test_cpu_offloading_v1.py b/tests/pytorch/test_cpu_offloading_v1.py index 8a8e03630..07091ee7a 100644 --- a/tests/pytorch/test_cpu_offloading_v1.py +++ b/tests/pytorch/test_cpu_offloading_v1.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -34,7 +36,7 @@ # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. -assert os.getenv("NVTE_FLASH_ATTN") == "0" +assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" # CPU offload v1 code path is enabled assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index d21c2e366..21fe4700b 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,7 +11,7 @@ import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Float8CurrentScaling +from transformer_engine.common.recipe import Float8CurrentScaling, Format from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype @@ -847,7 +849,7 @@ def test_fp8_current_scaling_linear_large_numel_e4m3(self, dtype, shape): pytest.skip(f"Skipping {shape}: insufficient device memory for allocation.") try: - with fp8_autocast(enabled=True, fp8_recipe=recipe): + with autocast(enabled=True, recipe=recipe): y = layer(x) except torch.OutOfMemoryError: pytest.skip(f"Skipping {shape}: OOM during forward.") diff --git a/tests/pytorch/test_layernorm_saved_tensors_logic.py b/tests/pytorch/test_layernorm_saved_tensors_logic.py index cb7760b5f..ab7d5d9d4 100644 --- a/tests/pytorch/test_layernorm_saved_tensors_logic.py +++ b/tests/pytorch/test_layernorm_saved_tensors_logic.py @@ -1,11 +1,11 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information import pytest import torch import torch.nn as nn from unittest.mock import patch -from transformer_engine.pytorch import LayerNormLinear, LayerNormMLP, fp8_autocast +from transformer_engine.pytorch import LayerNormLinear, LayerNormMLP, autocast from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager @@ -84,7 +84,7 @@ def spy_on_ctx(ctx, *args, **kwargs): weight_tensor.requires_grad_(True) with patch(config["backward_target"], side_effect=spy_on_ctx) as mock_backward: - with fp8_autocast(enabled=True): + with autocast(enabled=True): out, ln_out_returned = model(inp) out.backward(grad_output, retain_graph=True) @@ -99,7 +99,7 @@ def spy_on_ctx(ctx, *args, **kwargs): saved_ln_out_container.clear() with patch(config["backward_target"], side_effect=spy_on_ctx) as mock_backward: - with fp8_autocast(enabled=True): + with autocast(enabled=True): out, ln_out_returned = model(inp) out.backward(grad_output) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 50578fc1a..9fe6304d4 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1196,7 +1196,7 @@ def _test_granular_accuracy_with_fp8(block, bs, dtype, config): ) inp_hidden_states.retain_grad() - with fp8_autocast(enabled=True): + with autocast(enabled=True): out = block(inp_hidden_states) loss = out.sum() loss.backward() @@ -1357,10 +1357,11 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model module = LayerNormLinear config = model_configs[model] - with fp8_model_init(enabled=fp8_model_params): + with quantized_model_init(enabled=fp8_model_params): layer = module( config.hidden_size, 4 * config.hidden_size, + config.eps, bias=True, params_dtype=dtype, device="cuda", @@ -1371,6 +1372,7 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model ref_layer = module( config.hidden_size, 4 * config.hidden_size, + config.eps, bias=True, params_dtype=dtype, device="cuda", diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h index a0dd325da..73edf23a3 100644 --- a/transformer_engine/common/include/transformer_engine/hadamard_transform.h +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -11,6 +13,8 @@ #ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ #define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ +#ifndef __HIP_PLATFORM_AMD__ + #include "transformer_engine.h" #ifdef __cplusplus @@ -65,4 +69,6 @@ void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTE } // extern "C" #endif +#endif + #endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index c55f1f612..674d4e4cb 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -33,6 +33,7 @@ class _FormatMaxVals(Enum): """ Tuples of FP8 (OCP, FNUZ) values for different formats. """ + E2M1 = (6, 6) E4M3 = (448, 240) E5M2 = (57344, 57344) @@ -53,12 +54,11 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - #TODO: bring E2M1 back after rocm support MXFP4 - if not te_rocm_build: - E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) - E4M3 = _FormatHelper(max_fwd=_FormatMaxVals.E4M3.value, max_bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(max_fwd=_FormatMaxVals.E5M2.value, max_bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) + #TODO: Change max vals after rocm support MXFP4 + E2M1 = _FormatHelper(fwd=_FormatMaxVals.E2M1.value, bwd=_FormatMaxVals.E2M1.value) + E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) + HYBRID = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E5M2.value) @dataclass(frozen=True) diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 74852b22d..55d8aafd6 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -293,6 +293,7 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; +#ifndef __HIP_PLATFORM_AMD__ class NVFP4Quantizer : public Quantizer { public: // fp4 dtype @@ -346,6 +347,7 @@ class NVFP4Quantizer : public Quantizer { void quantize_impl(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag, bool compute_amax); }; +#endif // #ifndef __HIP_PLATFORM_AMD__ std::unique_ptr convert_quantizer(py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index de1e3ccbd..ebdbb5817 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -41,6 +41,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -51,6 +54,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; } } +#endif // Perform compute auto stream = at::cuda::getCurrentCUDAStream(); @@ -101,6 +105,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { @@ -119,6 +124,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; +#endif default: NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } @@ -153,6 +159,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -163,6 +172,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; } } +#endif // Perform compute auto stream = at::cuda::getCurrentCUDAStream(); @@ -213,6 +223,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { @@ -231,6 +242,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; +#endif default: NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index dcff95887..1d3e27a14 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -151,6 +151,9 @@ std::vector dact_dbias( impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -161,6 +164,7 @@ std::vector dact_dbias( impl = Impl::FUSED_DACT_AMAX_NVFP4; } } +#endif // Perform compute auto stream = at::cuda::getCurrentCUDAStream(); @@ -220,6 +224,7 @@ std::vector dact_dbias( fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_DACT_AMAX_NVFP4: // Fused dact-amax kernel, unfused dbias and NVFP4 quantize { @@ -237,6 +242,7 @@ std::vector dact_dbias( nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } +#endif default: NVTE_ERROR("Invalid implementation"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 48c02215f..97b9e7ca6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -493,6 +493,7 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } +#ifndef __HIP_PLATFORM_AMD__ // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate @@ -693,6 +694,7 @@ std::tuple, std::vector> bulk_allocate_nv return retval; } +#endif // #ifndef __HIP_PLATFORM_AMD__ } // namespace @@ -791,6 +793,7 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); +#ifndef __HIP_PLATFORM_AMD__ } else if (is_nvfp4) { // NVFP4: construct output tensors with bulk allocations std::vector nvfp4_quantizers; @@ -799,6 +802,7 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); +#endif } else { NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); } diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 060a342a6..6d635e1c2 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,6 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -131,6 +134,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe impl = Impl::FUSED_NORM_AMAX_NVFP4; } } + #endif // Construct unquantized output tensor if needed TensorWrapper unquantized_out_nvte; @@ -148,12 +152,14 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#endif default: { } } @@ -191,10 +197,12 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#endif default: { } } @@ -344,6 +352,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; +#ifdef __HIP_PLATFORM_AMD__ + } +#else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); @@ -355,6 +366,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w impl = Impl::FUSED_NORM_AMAX_NVFP4; } } +#endif // Construct unquantized output tensor if needed TensorWrapper unquantized_out_nvte; @@ -372,12 +384,14 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; +#endif default: { } } @@ -413,10 +427,12 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#ifndef __HIP_PLATFORM_AMD__ case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; +#endif default: { } } diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 65665d01b..1c1855669 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -108,8 +108,12 @@ constexpr std::array custom_types_converters = { CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), +#ifdef __HIP_PLATFORM_AMD__ +}; +#else std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, CreateQuantizer)}; +#endif } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index ffef3e59c..7240c3bf3 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1143,6 +1143,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +#ifndef __HIP_PLATFORM_AMD__ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->with_rht = quantizer.attr("with_rht").cast(); @@ -1719,5 +1720,6 @@ std::vector NVFP4Quantizer::get_scale_shape(const std::vector& s } return scale_shape; } +#endif } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 661cf3f2e..0ad1e86a4 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -51,7 +51,7 @@ from ..triton_kernels.cast import te_quantize_triton from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage -from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype +from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a906ea42d..89af05f93 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -444,7 +444,7 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 3fefb650e..fb89d6195 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -576,9 +576,9 @@ def forward( if is_grad_enabled: # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(fc1_weight_final, QuantizedTensorBase): + if isinstance(fc1_weight_final, QuantizedTensorStorage): fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorBase): + if isinstance(fc2_weight_final, QuantizedTensorStorage): fc2_weight_final.update_usage(columnwise_usage=True) if cpu_offloading: @@ -897,7 +897,7 @@ def backward( if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( - ctx.fc2_weight, QuantizedTensorStorage + fc2_weight, QuantizedTensorStorage ): fc2_weight.update_usage(columnwise_usage=True) @@ -1155,7 +1155,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( - ctx.fc1_weight_quantizer, QuantizedTensorStorage + fc1_weight, QuantizedTensorStorage # this fixes a bug with upstream usage of fc1_weight_quantizer ): fc1_weight.update_usage(columnwise_usage=True) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index ffa47d986..0d43776f1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -410,7 +410,7 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading and saved_inputmat is not None: diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0b7eddb9f..915527736 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -260,6 +260,7 @@ class FP8GlobalStateManager: HIGH_PRECISION_INIT_VAL = False IS_FIRST_FP8_MODULE = False FP8_GRAPH_CAPTURING = False + SKIP_FP8_REDUCTION_FOR_FSDP2 = False AUTOCAST_DEPTH = 0 global_amax_buffer = {} global_amax_history_buffer = {} diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 1848a60cf..f9ff4b77b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -142,7 +142,9 @@ def make_empty( columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: - columnwise_data = torch.empty_like(data) + columnwise_data = torch.empty( + shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) # ROCm TE does not implement fuse padding zeros so use zero tensor here columnwise_scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index b6a7270a3..eae6bb79c 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -10,11 +10,11 @@ from ..utils import is_non_tn_fp8_gemm_supported -from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from .cast_transpose import te_cast_transpose_mxfp8_triton, te_cast_transpose_noop_triton, te_dequantize_mxfp8_triton import transformer_engine_torch as tex -from ..tensor.quantized_tensor import QuantizedTensor, Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..quantized_tensor import QuantizedTensor, Quantizer +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: @@ -72,7 +72,7 @@ def te_quantize_triton( _setup_conditional_transpose_storage(out) else: out = quantizer.make_empty(input_tensor.shape, dtype=fake_tensor_type) - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): _setup_conditional_transpose_storage(out) else: # Create a QuantizedTensor from the provided output tensor @@ -82,11 +82,11 @@ def te_quantize_triton( if noop_flag is None: noop_flag = _empty_tensor() # if it's mxfp8, we'll check if both rowwise and columnwise data are none - if (isinstance(out, MXFP8TensorBase) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorBase) and out.size().numel() == 0): + if (isinstance(out, MXFP8TensorStorage) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorStorage) and out.size().numel() == 0): # Return empty output if the quantized tensor has no elements return out - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): if input_tensor.nelement() > 0: if not out._transpose_invalid: quantizer = out._get_quantizer() @@ -117,7 +117,7 @@ def te_quantize_triton( else: out.remove_caches() #Make sure to remove transpose if it is marked as invalid out = tex.quantize(input_tensor, quantizer, out, noop_flag) - elif isinstance(out, MXFP8TensorBase): + elif isinstance(out, MXFP8TensorStorage): te_cast_transpose_mxfp8_triton(input_tensor, out) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(out).__name__}'") @@ -125,9 +125,9 @@ def te_quantize_triton( return out def te_dequantize_triton(input, dtype: tex.DType): - if isinstance(input, MXFP8TensorBase): + if isinstance(input, MXFP8TensorStorage): return te_dequantize_mxfp8_triton(input, dtype) - elif isinstance(input, Float8TensorBase): + elif isinstance(input, Float8TensorStorage): return tex.dequantize(input, dtype) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(input).__name__}'") diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 3baa64697..37f504e4c 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -10,7 +10,7 @@ from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..constants import TE_DType from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer from ..triton_kernels.cast import te_quantize_triton import triton import triton.language as tl diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 9f152582e..3acbf0835 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -13,7 +13,7 @@ te_dtype_to_triton_dtype, ) from .common import get_fp8_max -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer import transformer_engine_torch as tex def dg_tmp_rows(x, sm_margin=None): From 8f94cf652f6989483e4525f8f3812c6046fa9543 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 09:47:40 -0600 Subject: [PATCH 136/141] jax and pytorch bugfix --- tests/jax/test_custom_call_compute.py | 15 ++++++++------- tests/jax/test_fused_attn.py | 4 ++-- tests/jax/utils.py | 2 +- tests/pytorch/attention/run_attention_with_cp.py | 4 +++- tests/pytorch/attention/test_attention_with_cp.py | 7 ++++--- .../jax/cpp_extensions/attention.py | 1 + transformer_engine/jax/cpp_extensions/gemm.py | 6 +++--- .../jax/cpp_extensions/normalization.py | 1 + transformer_engine/jax/csrc/extensions/amax.cpp | 4 ++++ transformer_engine/jax/csrc/extensions/gemm.cpp | 6 ++++++ .../jax/csrc/extensions/quantization.cpp | 2 ++ transformer_engine/pytorch/distributed.py | 6 ++++++ 12 files changed, 41 insertions(+), 17 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 75b606a62..3b9ee0034 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -636,7 +636,7 @@ def test_norm_forward_with_block_scaling_fp8( ) -QUANTIZE_OUTPUT_DTYPES = { +QUANTIZE_OUTPUT_FP8_DTYPES = { "L0": [jnp_float8_e4m3_type], "L2": FP8_COMPUTE_TYPE, } @@ -1790,11 +1790,12 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) - num_gemms = input_shape[0] - _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( - group_sizes, - num_gemms=num_gemms, - ) + if not is_hip_extension(): + num_gemms = input_shape[0] + _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( + group_sizes, + num_gemms=num_gemms, + ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm @@ -1805,7 +1806,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): rhs, group_sizes, contracting_dims, - use_async_d2h_group_sizes=True, + use_async_d2h_group_sizes=not is_hip_extension(), ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f33961455..a08a1fe42 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1026,14 +1026,14 @@ def check_dqkv(primitive, reference, pad, idx): ), pytest.param( 2, - 512, + 2048, 1024, 12, 12, 64, 64, jnp.bfloat16, - id="2-512-1024-12-12-64-64-BF16-CROSS", + id="2-2048-1024-12-12-64-64-BF16-CROSS", ), pytest.param( 2, 2048, 2048, 12, 6, 64, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-64-BF16-GQA" diff --git a/tests/jax/utils.py b/tests/jax/utils.py index a0e5e708b..60b01348f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 80f21048f..b59fe6451 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -451,9 +451,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[4] = tensors_to_deq - for tensor in tensors: + i = 0 + for tensor in tensors[4:]: assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) + i += 1 out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 9ac96dcef..d0956c226 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -24,6 +24,7 @@ Float8CurrentScaling, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) @@ -144,8 +145,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: ModelConfig(b, sq, hq, dqk) - "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=True), # MHA - "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=True), # MHA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=not IS_HIP_EXTENSION), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=not IS_HIP_EXTENSION), # MHA "cp_1_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA @@ -305,7 +306,7 @@ def test_cp_with_fused_attention( ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, - qkv_dtype=dtypes[dtype] if dtype != "fp8" else torch.float8_e4m3fn, + qkv_dtype=dtypes[dtype] if dtype != "fp8" else get_torch_float8_e4m3_type(), qkv_layout="_".join([qkv_format] * 3), fp8=fp8, fp8_meta=fp8_meta, diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 428f3ba2e..ab2a4562e 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -9,6 +9,7 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce +from packaging import version from typing import Optional, Tuple import jax diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 756913c91..a04a98d97 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -24,8 +24,8 @@ get_num_compute_streams, JAXX_Collective_Op, get_device_compute_capability, - initialize_cgemm_communicator, - get_cgemm_num_max_streams, + #initialize_cgemm_communicator, + #get_cgemm_num_max_streams, ) from .base import BasePrimitive, register_primitive @@ -83,7 +83,7 @@ def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" if is_hip_extension(): """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) == 95: + if get_device_compute_capability(0) == 95: return 67_108_864 return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 1bf6ec943..e53d63625 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -8,6 +8,7 @@ import warnings import operator from functools import partial, cache, reduce +from packaging import version from typing import Optional, Union import jax diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 46f167fca..a4b590250 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -1,8 +1,11 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#ifndef __HIP_PLATFORM_AMD__ #include #include @@ -98,3 +101,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( } // namespace jax } // namespace transformer_engine +#endif // #ifndef __HIP_PLATFORM_AMD__ \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 41b78f117..f038101b2 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -87,6 +87,9 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } } else { // Swizzle for NVFP4 NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); +#ifdef __HIP_PLATFORM_AMD__ + } +#else input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); @@ -97,6 +100,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); } +#endif // #ifdef __HIP_PLATFORM_AMD__ } return std::make_tuple(std::move(input), input_shape); @@ -767,6 +771,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); +#ifndef __HIP_PLATFORM_AMD__ if (is_mxfp8_scaling) { for (int i = 0; i < num_non_empty_gemms; i++) { // The i-th GEMM will use the (i % num_streams)-th stream to compute, @@ -778,6 +783,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); } } +#endif // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 1f7db8438..626c47276 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -177,6 +177,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } if (is_quantize_colwise(quantize_layout)) { +#ifndef __HIP_PLATFORM_AMD__ if (is_nvfp4 && use_rht) { if (is_quantize_2x2x(quantize_layout)) { // Do regular rowwise quantization without RHT @@ -218,6 +219,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T return ffi_with_cuda_error_check(); } +#endif // #ifndef __HIP_PLATFORM_AMD__ bool const is_colwise_transposed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c7329caca..04ffa324d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -952,7 +952,13 @@ def _all_gather_fp8( if isinstance(inp, Float8Tensor): dtype = inp.dtype device = inp.device + # Temporarily ensure rowwise usage for output tensor creation + # since we're gathering rowwise data, not the transpose + init_rowwise_usage = quantizer.rowwise_usage + init_columnwise_usage = quantizer.columnwise_usage + quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage) out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) out._data = torch.empty( From bac79938e79698010f749470bb802260a76a64f0 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 10:41:15 -0600 Subject: [PATCH 137/141] copyrights and fp8_autocast->autocast fix --- build_tools/jax.py | 2 +- build_tools/pytorch.py | 2 +- build_tools/utils.py | 2 +- tests/cpp/operator/CMakeLists.txt | 2 +- tests/pytorch/distributed/test_fusible_ops.py | 2 +- tests/pytorch/test_fused_optimizer.py | 2 +- tests/pytorch/triton_kernels/test_cast.py | 10 +++++----- transformer_engine/common/CMakeLists.txt | 2 +- transformer_engine/common/__init__.py | 2 +- transformer_engine/common/cast/cast.cu | 2 ++ transformer_engine/common/cast/dispatch/dequantize.cuh | 2 ++ transformer_engine/common/cast/fp8/dequantize_fp8.cuh | 2 ++ transformer_engine/common/cast/fp8/gated_fp8.cuh | 2 ++ .../common/cast/mxfp8/dequantize_mxfp8.cuh | 2 +- transformer_engine/common/common.h | 2 +- .../common/fused_attn_rocm/fused_attn.cpp | 2 +- .../common/fused_attn_rocm/fused_attn_aotriton.cpp | 2 +- .../common/fused_attn_rocm/fused_attn_aotriton.h | 2 +- .../common/fused_attn_rocm/fused_attn_ck.cpp | 2 +- .../common/fused_attn_rocm/fused_attn_ck.h | 2 +- transformer_engine/common/gemm/cublaslt_gemm.cu | 2 +- transformer_engine/common/gemm/rocm_gemm.cu | 2 +- transformer_engine/common/recipe/__init__.py | 2 +- transformer_engine/common/util/logging.h | 2 +- transformer_engine/jax/csrc/extensions.h | 2 +- transformer_engine/jax/csrc/extensions/pybind.cpp | 2 +- transformer_engine/jax/quantize/helper.py | 2 +- .../attention/dot_product_attention/backends.py | 2 +- .../dot_product_attention/context_parallel.py | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- .../pytorch/csrc/extensions/activation.cpp | 2 +- transformer_engine/pytorch/csrc/extensions/cast.cpp | 2 +- .../pytorch/csrc/extensions/normalization.cpp | 2 +- transformer_engine/pytorch/module/layernorm_mlp.py | 2 +- transformer_engine/pytorch/triton_kernels/layernorm.py | 2 +- transformer_engine/pytorch/triton_kernels/rmsnorm.py | 2 +- 36 files changed, 44 insertions(+), 36 deletions(-) diff --git a/build_tools/jax.py b/build_tools/jax.py index e67036f49..6f2e57f87 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 35d910bcd..4bbefd730 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/build_tools/utils.py b/build_tools/utils.py index 0c34bedde..0c18d7ecf 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index cd36993ce..d3b75bbbf 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 85ae2d85b..b9fbfb2a5 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index eeeda171e..5526103d5 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/pytorch/triton_kernels/test_cast.py b/tests/pytorch/triton_kernels/test_cast.py index 3f725c496..f85773d65 100644 --- a/tests/pytorch/triton_kernels/test_cast.py +++ b/tests/pytorch/triton_kernels/test_cast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information import pytest @@ -10,7 +10,7 @@ from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype import transformer_engine_torch as tex from test_common import te_compare_results, fill_uniform, get_tolerances -from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.pytorch.fp8 import autocast from transformer_engine.common import recipe from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type @@ -43,7 +43,7 @@ def test_quantize(scaling, shape, in_dtype, out_dtype): triton_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda") tex_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer) quantized_out_tex = tex.quantize(input_tensor, tex_quantizer) @@ -187,13 +187,13 @@ def test_amax_atomic_vs_two_stage(shape, in_dtype, out_dtype): # atomic amax os.environ[env_key] = "1" - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): out_atomic = te_quantize_triton(input_tensor, quantizer=quantizer_atomic) # 2-stage amax os.environ[env_key] = "0" - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): out_2stage = te_quantize_triton(input_tensor, quantizer=quantizer_2stage) te_compare_results( diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 46eb5dba5..831de2b45 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index cdda37508..f0335f44c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 575106a53..7ecc05d2e 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/dispatch/dequantize.cuh b/transformer_engine/common/cast/dispatch/dequantize.cuh index f55719852..4ba64ca97 100644 --- a/transformer_engine/common/cast/dispatch/dequantize.cuh +++ b/transformer_engine/common/cast/dispatch/dequantize.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh index 5d30a6c3f..22a3929e3 100644 --- a/transformer_engine/common/cast/fp8/dequantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/dequantize_fp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index c9040a3da..aa46a574c 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 96aed3e88..5701a446d 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 03b90febb..5feeb600c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index d39fccbce..48a309118 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index 1c25fa031..9109ddb15 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index 178bd8d8f..fd4dffd73 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 8d639c47c..02bc9ce94 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h index 926c90866..0772609ff 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 343b3cecb..0127a9edf 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 97bd2e8a7..205a0a058 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 674d4e4cb..b90cd5ce3 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 09187069e..ebcf99afe 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 845176080..2bfa4c89f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index bc47ef6bd..937dde228 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 792173ed1..b8a8809fc 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 5437b73bc..a0aaab1f3 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 13b41345b..096eca809 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 55d8aafd6..fd83f20d4 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index ebdbb5817..6936d6bc8 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 97b9e7ca6..f3c77a332 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 6d635e1c2..805579ff4 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index fb89d6195..fb3327156 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 37f504e4c..86b7b46c7 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 3acbf0835..1ca6183c9 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information import torch From 8ae38e8e66df970c2cc165dee6d7f0bd98d35250 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 24 Feb 2026 13:35:14 -0600 Subject: [PATCH 138/141] Enable test_distributed_dense.py --- ci/jax.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/jax.sh b/ci/jax.sh index 81d994585..d350ebac7 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -86,6 +86,7 @@ run_test_config_mgpu() { if [ "$TEST_LEVEL" -le 3 ]; then TEST_ERROR_IGNORE="1" fi + run_default_fa 2 test_distributed_dense.py run $_dfa_level test_distributed_fused_attn.py $_timeout_args TEST_ERROR_IGNORE="" run_default_fa 3 test_distributed_layernorm.py From 05a977a8e089c4e3beb308267ce0d8cc3a6416d9 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Mar 2026 00:18:46 -0600 Subject: [PATCH 139/141] address IFU comments --- build_tools/wheel_utils/build_wheels.sh | 8 +- ci/jax.sh | 12 +-- setup.py | 7 +- tests/jax/distributed_test_base.py | 2 - tests/jax/test_fused_attn.py | 1 + .../attention/run_attention_with_cp.py | 1 - tests/pytorch/utils.py | 4 +- .../common/cast/core/common.cuh | 2 + .../common/cast/mxfp8/dequantize_mxfp8.cuh | 2 + .../common/cast/mxfp8/gated_mxfp8.cuh | 2 + .../common/cast/mxfp8/quantize_mxfp8.cuh | 2 + .../cast/mxfp8/rocm_dequantize_mxfp8.cuh | 7 +- .../common/cast/mxfp8/rocm_gated_mxfp8.cuh | 14 ++-- .../common/cast/mxfp8/rocm_quantize_mxfp8.cuh | 8 +- transformer_engine/common/common.cu | 2 +- .../common/fused_attn_rocm/fused_attn.cpp | 1 + transformer_engine/common/recipe/__init__.py | 32 +++----- .../common/recipe/current_scaling.cu | 2 +- transformer_engine/common/swizzle/swizzle.cu | 3 +- transformer_engine/common/util/ptx.cuh | 73 ------------------- .../common/util/rocm_vectorized_2d.cuh | 13 ---- .../jax/cpp_extensions/attention.py | 4 +- transformer_engine/jax/cpp_extensions/gemm.py | 3 +- .../jax/csrc/extensions/amax.cpp | 4 +- .../jax/csrc/extensions/cgemm_helper.h | 4 + .../jax/csrc/extensions/gemm.cpp | 11 +-- transformer_engine/jax/csrc/extensions/misc.h | 2 + .../jax/csrc/extensions/quantization.cpp | 4 +- transformer_engine/jax/quantize/helper.py | 2 + transformer_engine/jax/setup.py | 5 +- .../dot_product_attention/backends.py | 4 + .../dot_product_attention.py | 2 + .../attention/dot_product_attention/utils.py | 7 +- .../pytorch/cpp_extensions/fused_attn.py | 4 + transformer_engine/pytorch/csrc/common.cpp | 4 +- .../pytorch/csrc/extensions/activation.cpp | 8 +- .../pytorch/csrc/extensions/bias.cpp | 4 +- .../pytorch/csrc/extensions/cast.cpp | 6 +- .../pytorch/csrc/extensions/normalization.cpp | 12 +-- .../pytorch/csrc/extensions/recipe.cpp | 2 +- transformer_engine/pytorch/csrc/pybind.h | 2 + transformer_engine/pytorch/csrc/quantizer.cpp | 4 +- transformer_engine/pytorch/fp8.py | 2 - transformer_engine/pytorch/module/base.py | 2 +- transformer_engine/pytorch/quantization.py | 2 + transformer_engine/pytorch/setup.py | 12 ++- .../pytorch/tensor/float8_tensor.py | 2 +- transformer_engine/pytorch/utils.py | 4 +- 48 files changed, 130 insertions(+), 190 deletions(-) delete mode 100644 transformer_engine/common/util/rocm_vectorized_2d.cuh diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh index 0be852c8a..0e8ab68a8 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -75,12 +75,12 @@ if $BUILD_COMMON ; then ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt if [ "$ROCM_BUILD" = "1" ]; then - # Repack the wheel for cuda specific package, i.e. cu12. + # Repack the wheel for specific rocm package. ${PYBINDIR}wheel unpack dist/* # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). - sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" + sed -i "s/Name: transformer-engine/Name: transformer-engine-rocm/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_rocm/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_rocm-${VERSION}.dist-info" ${PYBINDIR}wheel pack ${WHL_BASE} else # Repack the wheel for specific cuda version. diff --git a/ci/jax.sh b/ci/jax.sh index d350ebac7..ef9dbe124 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -59,8 +59,7 @@ run_test_config() { run_default_fa 1 test_functions.py run 1 test_fused_attn.py NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass - run_default_fa 1 test_helper.py - run_default_fa 1 test_layer.py #it effectevly always uses unfused attention + run_default_fa 1 test_layer.py # it effectively always uses unfused attention run_default_fa 1 test_sanity_import.py run_default_fa 1 test_softmax.py } @@ -71,7 +70,7 @@ run_test_config_mgpu() { # Mitigate distributed tests hang by adding 5min timeout _timeout_args="--timeout 300 --timeout-method thread" - # Workaround for some distributed tests hang/abotrion + # Workaround for some distributed tests hang/abortion export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then @@ -81,12 +80,13 @@ run_test_config_mgpu() { _dfa_level=3 export NVTE_JAX_UNITTEST_LEVEL=L2 fi + + run_default_fa 2 test_distributed_dense.py # Do not fail automated CI if test_distributed_fused_attn is hung - # If the sctipt run w/o TEST_LEVEL the test error will be honored + # If the script runs w/o TEST_LEVEL the test error will be honored if [ "$TEST_LEVEL" -le 3 ]; then TEST_ERROR_IGNORE="1" fi - run_default_fa 2 test_distributed_dense.py run $_dfa_level test_distributed_fused_attn.py $_timeout_args TEST_ERROR_IGNORE="" run_default_fa 3 test_distributed_layernorm.py @@ -96,7 +96,7 @@ run_test_config_mgpu() { run_default_fa 3 test_sanity_import.py } -# Single config mode, run it synchroniously and return result +# Single config mode, run it synchronously and return result if [ -n "$SINGLE_CONFIG" ]; then _fus_attn="$SINGLE_CONFIG" configure_fused_attn_env $_fus_attn && run_test_config diff --git a/setup.py b/setup.py index bec4943e1..eb241f5cb 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ import time from pathlib import Path from typing import List, Tuple -import subprocess import setuptools from setuptools.command.egg_info import egg_info @@ -240,7 +239,7 @@ def git_check_submodules() -> None: assert bool( int(os.getenv("NVTE_RELEASE_BUILD", "0")) ), "NVTE_RELEASE_BUILD env must be set for metapackage build." - te_cuda_vers = "rocm" if rocm_build() else "cu12" + te_cuda_vers = "cu12" ext_modules = [] cmdclass = {} package_data = {} @@ -253,7 +252,7 @@ def git_check_submodules() -> None: "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } if not rocm_build() else { - "core": [f"transformer_engine_{te_cuda_vers}=={__version__}"], + "core": [f"transformer_engine_rocm=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } @@ -303,7 +302,7 @@ def git_check_submodules() -> None: long_description=long_description, long_description_content_type="text/x-rst", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist} if not rocm_build() else {"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, + cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, python_requires=f">={min_python_version_str()}", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index e8d9cefd6..137fa480d 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index a08a1fe42..72797f556 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -387,6 +387,7 @@ def _check_configs(self): get_device_compute_capability(0) >= 100 and self.dropout_prob == 0.1 and self.attn_bias_type is not AttnBiasType.NO_BIAS + and not is_hip_extension() ): pytest.skip( "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index b59fe6451..9e59f4f6a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -455,7 +455,6 @@ def run_dpa_with_cp( for tensor in tensors[4:]: assert torch.all(~torch.isnan(tensor)) assert torch.all(~torch.isinf(tensor)) - i += 1 out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors ############ compare results between CP and no-CP ############ diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 05555626b..ed5a12995 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -100,9 +100,9 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) - if dtype in (torch.float8_e4m3fn, torch.float8_e4m3fnuz): + if dtype in torch_float8_e4m3_type: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype in (torch.float8_e5m2, torch.float8_e5m2fnuz): + if dtype in torch_float8_e5m2_type: return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index ec36e941f..540f7e252 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index 5701a446d..38eead606 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -25,6 +25,8 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" +#include "./rocm_vectorized_2d.cuh" + namespace transformer_engine { namespace dispatch { namespace mxfp8 { diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 28e46fc7a..69e30680c 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -25,6 +25,8 @@ #include "../../util/ptx.cuh" #include "../../utils.cuh" +#include "./rocm_vectorized_2d.cuh" + namespace transformer_engine { namespace dispatch { namespace mxfp8 { diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 19234e9b4..8e25b3f65 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -26,6 +26,8 @@ #include "../../utils.cuh" #include "../core/common.cuh" +#include "./rocm_vectorized_2d.cuh" + namespace transformer_engine { namespace dispatch { namespace mxfp8 { diff --git a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index 02224a69f..49c57737c 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -67,7 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -108,9 +108,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); - ptx::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, rows, cols); + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, + chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } diff --git a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index 7382b8aab..a8c02e4f8 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -122,16 +122,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate bulk tensor copy if constexpr (IS_DGATED) { - copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } // Act - copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); // Gate - copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -356,19 +356,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - ptx::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - ptx::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } diff --git a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index dc36fb42d..d5b51a2f4 100644 --- a/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -143,11 +143,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; if constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, + transformer_engine::rocm::copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -290,12 +290,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } if constexpr (USE_COLWISE_SCALING) { - ptx::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index ab574256c..a0193e95f 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index 48a309118..97aecf4de 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -282,6 +282,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; + NVTE_CHECK(!(return_max_logit || cuda_graph), "ROCm does not support return_max_logit and cuda_graph for fused_attn yet."); // by default, fused attn is enabled bool nvte_fused_attn = true; if (const char* env_p = std::getenv("NVTE_FUSED_ATTN") ) { diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b90cd5ce3..ae86e492d 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -18,24 +18,10 @@ class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. """ - fwd: tuple - bwd: tuple - @property - def max_fwd(self) -> float: - return self.fwd[is_fp8_fnuz()] + max_fwd: float + max_bwd: float - @property - def max_bwd(self) -> float: - return self.bwd[is_fp8_fnuz()] - -class _FormatMaxVals(Enum): - """ - Tuples of FP8 (OCP, FNUZ) values for different formats. - """ - E2M1 = (6, 6) - E4M3 = (448, 240) - E5M2 = (57344, 57344) class Format(Enum): """ @@ -54,11 +40,15 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - #TODO: Change max vals after rocm support MXFP4 - E2M1 = _FormatHelper(fwd=_FormatMaxVals.E2M1.value, bwd=_FormatMaxVals.E2M1.value) - E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E5M2.value) + + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) + if te_rocm_build: + max_e4m3_val = 240 if is_fp8_fnuz() else 448 + E4M3 = _FormatHelper(max_fwd=max_e4m3_val, max_bwd=max_e4m3_val) + else: + E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) + E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) + HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) @dataclass(frozen=True) diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index 69b44494b..55aa2907e 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 881b134e7..a00c30a9c 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -24,8 +24,8 @@ namespace { #define __ldg(x) (*(x)) #endif -#ifndef __HIP_PLATFORM_AMD__ constexpr int MXFP8_BLOCK_SIZE = 32; +#ifndef __HIP_PLATFORM_AMD__ constexpr int NVFP4_BLOCK_SIZE = 16; constexpr __device__ __host__ int TB_DIM = 32; @@ -38,7 +38,6 @@ constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; #else // HIPCC does not support __host__ qualifier for variables // and constexpr values do not need __device__ qualifier because they are compile-time constants -constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int TB_DIM = 32; constexpr int NEW_SF_TILE_DIM_K = 16; constexpr int N_SF_PER_TD_PER_TILE = 4; diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index ef53c2670..312890db0 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -318,38 +318,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_ #endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } -#ifdef __HIP_PLATFORM_AMD__ -template -__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); - - T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); - - shared_loader.load(l_x_vec, chunk_dim_x); - - if (g_row < total_rows) { - global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; - global_storer.store(g_col_primitive_start / N_VEC, total_cols); - } - } -} -#endif //#ifdef __HIP_PLATFORM_AMD__ - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( @@ -931,47 +899,7 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -#ifdef __HIP_PLATFORM_AMD__ -// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. -template -__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, - size_t g_start_row, size_t g_stride, size_t chunk_dim_y, - size_t chunk_dim_x, size_t total_rows, - size_t total_cols) { - size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; - const size_t l_idx = threadIdx.x; - - for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { - size_t l_y = (i_vec / chunk_dim_x_vec_elements); - size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); - - size_t g_row = g_start_row + l_y; - size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; - - if (g_row < total_rows) { - const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; - VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); - - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); - - global_loader.load(g_col_primitive_start / N_VEC, total_cols); - shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; - shared_storer.store(l_x_vec, chunk_dim_x); - } else { - T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; - VectorizedStorer shared_storer(current_sh_row_base_ptr, chunk_dim_x); - -#pragma unroll - for (int i = 0; i < N_VEC; ++i) { - shared_storer.separate()[i] = static_cast(0); - } - shared_storer.store(l_x_vec, chunk_dim_x); - } - } -} -#else __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, const size_t chunk_Y, const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { @@ -992,7 +920,6 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -#endif //#ifdef __HIP_PLATFORM_AMD__ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, const size_t chunk_X1, const size_t chunk_Y1, diff --git a/transformer_engine/common/util/rocm_vectorized_2d.cuh b/transformer_engine/common/util/rocm_vectorized_2d.cuh deleted file mode 100644 index eda0f437f..000000000 --- a/transformer_engine/common/util/rocm_vectorized_2d.cuh +++ /dev/null @@ -1,13 +0,0 @@ -/************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - * - * License for AMD contributions = MIT. See LICENSE for more information - ************************************************************************/ - -#pragma once - -#include "../util/vectorized_pointwise.h" - -namespace transformer_engine { - -} // namespace transformer_engine diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index ab2a4562e..ba6d01a9f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -2786,7 +2786,7 @@ def fused_attn_bwd( # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on # sm100+ compute_capabilities = get_all_device_compute_capability() - if any(x >= 100 for x in compute_capabilities): + if any(x >= 100 for x in compute_capabilities) and not is_hip_extension(): assert not ( attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index a04a98d97..2daecedfa 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -285,7 +285,8 @@ def collective_gemm_bootstrap( and before any collective GEMM operations. Each process should call this function with its own unique process_id. """ - + if is_hip_extension(): + assert 0, "collective_gemm_bootstrap is not supported for ROCm yet." assert ( num_devices_per_process == 1 and jax.local_device_count() == 1 ), "Only single device per process is supported at the moment!" diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index a4b590250..050d0fd23 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -5,7 +5,7 @@ * * See LICENSE for license information. ************************************************************************/ -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM #include #include @@ -101,4 +101,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( } // namespace jax } // namespace transformer_engine -#endif // #ifndef __HIP_PLATFORM_AMD__ \ No newline at end of file +#endif // #ifndef USE_ROCM \ No newline at end of file diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h index 84b2b8154..03d86c168 100644 --- a/transformer_engine/jax/csrc/extensions/cgemm_helper.h +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -7,6 +9,7 @@ #ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ #define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ +#ifndef USE_ROCM #include #include @@ -186,4 +189,5 @@ int GetCgemmNumMaxStreams(); } // namespace jax } // namespace transformer_engine +#endif // #ifndef USE_ROCM #endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f038101b2..d35b2d072 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -86,10 +86,11 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } } else { // Swizzle for NVFP4 - NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM + NVTE_ERROR("ROCm TE does not support NVFP4 yet."); } #else + NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); // Create tensor to hold swizzled scale factor TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); @@ -100,7 +101,7 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set swizzled scales into the input tensor input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); } -#endif // #ifdef __HIP_PLATFORM_AMD__ +#endif // #ifdef USE_ROCM } return std::make_tuple(std::move(input), input_shape); @@ -285,7 +286,7 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } else { #ifdef USE_ROCM //TODO: better assert - std::cerr<<"ROCm TE jax does not integrate userbuffer for now"< buffer_shape{0, 0}; DType buffer_dtype = out_dtype; @@ -771,7 +772,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM if (is_mxfp8_scaling) { for (int i = 0; i < num_non_empty_gemms; i++) { // The i-th GEMM will use the (i % num_streams)-th stream to compute, diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index a0c5db5a8..c2d3d6f25 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 626c47276..9a0a87d69 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -177,7 +177,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T } if (is_quantize_colwise(quantize_layout)) { -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM if (is_nvfp4 && use_rht) { if (is_quantize_2x2x(quantize_layout)) { // Do regular rowwise quantization without RHT @@ -219,7 +219,7 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T return ffi_with_cuda_error_check(); } -#endif // #ifndef __HIP_PLATFORM_AMD__ +#endif // #ifndef USE_ROCM bool const is_colwise_transposed = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index b8a8809fc..95d5aea21 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -134,6 +134,8 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: def _check_fp4_support(gpu_arch) -> Tuple[bool, str]: """Check if FP4 is supported for the given GPU architecture.""" + if is_hip_extension(): + return False, "FP4 not yet supported for ROCm" if gpu_arch < 100: # pre-blackwell return False, "Device compute capability 10.0 or higher required for NVFP4 execution." if get_cublasLt_version() < 120800: diff --git a/transformer_engine/jax/setup.py b/transformer_engine/jax/setup.py index 0b958d3ad..619b6070b 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -122,7 +122,10 @@ def get_cuda_major_version() -> int: # us to detect CUDA version dynamically during compilation and # choose the correct wheel for te core lib. __version__ = te_version() - te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" + if not rocm_build(): + te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" + else: + te_core = f"transformer_engine_rocm=={__version__}" install_requires = install_requirements() + [te_core] # Configure package diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a0aaab1f3..038ebc3c0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -219,6 +219,8 @@ def __init__( softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.softmax_scale = softmax_scale @@ -1676,6 +1678,8 @@ def __init__( softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.softmax_scale = softmax_scale diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 4157e8d3a..ef601e4c4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -319,6 +319,8 @@ def __init__( softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.logger = logging.getLogger("DotProductAttention") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index c8da3161b..54fe21d81 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -482,7 +482,7 @@ def get_attention_backend( fp8_recipe = fp8_meta["recipe"] if fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] - if use_fused_attention and fp8_recipe.float8_current_scaling(): + if use_fused_attention and fp8_recipe.float8_current_scaling() and not IS_HIP_EXTENSION: if device_compute_capability < (10, 0): logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") use_fused_attention = False @@ -502,7 +502,7 @@ def get_attention_backend( ) use_fused_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and not IS_HIP_EXTENSION: if use_flash_attention: logger.debug( "Disabling FlashAttention as FP8 is not supported" @@ -599,6 +599,7 @@ def get_attention_backend( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) and is_training + and not IS_HIP_EXTENSION ): if use_fused_attention: logger.debug( @@ -679,7 +680,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and not IS_HIP_EXTENSION: if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 852dcdb59..e5492ebc6 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -264,6 +264,10 @@ def fused_attn_fwd( max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None """ + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." + assert not cuda_graph, "ROCm does not support cuda_graph." + if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index e1a78d49a..59f57743b 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,7 +12,7 @@ #include "pybind.h" #include "transformer_engine/transformer_engine.h" -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM #include "common/common.h" #endif @@ -312,7 +312,7 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM inline bool nvte_use_atomic_amax() { const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX"); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 6936d6bc8..205605312 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -41,7 +41,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -105,7 +105,7 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { @@ -159,7 +159,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { impl = Impl::FUSED_ACTIVATION_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -223,7 +223,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_ACTIVATION_AMAX_NVFP4: // Compute activation and amax in high precision, then quantize to NVFP4 { diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 1d3e27a14..e8a735966 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -151,7 +151,7 @@ std::vector dact_dbias( impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { @@ -224,7 +224,7 @@ std::vector dact_dbias( fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_DACT_AMAX_NVFP4: // Fused dact-amax kernel, unfused dbias and NVFP4 quantize { diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index f3c77a332..8fc4e1e97 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -493,7 +493,7 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM // allocate fp4 data, fp8 scalings, and amax values // layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] // amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate @@ -694,7 +694,7 @@ std::tuple, std::vector> bulk_allocate_nv return retval; } -#endif // #ifndef __HIP_PLATFORM_AMD__ +#endif // #ifndef USE_ROCM } // namespace @@ -793,7 +793,7 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM } else if (is_nvfp4) { // NVFP4: construct output tensors with bulk allocations std::vector nvfp4_quantizers; diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 805579ff4..839bb694a 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,7 +120,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -152,7 +152,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = @@ -197,7 +197,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); @@ -352,7 +352,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); impl = Impl::FUSED_NORM_AMAX_FP8; -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM } #else } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { @@ -384,7 +384,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); kernel_out_nvte = &unquantized_out_nvte; } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); std::tie(unquantized_out_nvte, unquantized_out) = @@ -427,7 +427,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); } break; -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM case Impl::FUSED_NORM_AMAX_NVFP4: { auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index fafdc3761..577a938f2 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -28,7 +28,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { DType::kFloat8E4M3, // It doesn't matter because we only compute amax. amax_ptr); -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM at::Tensor ws = allocate_amax_workspace(te_input); TensorWrapper tw = makeTransformerEngineTensor(ws); nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 1c1855669..a84641364 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7240c3bf3..90ed2a99f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -516,7 +516,7 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM at::Tensor ws = allocate_amax_workspace(input); TensorWrapper tw = makeTransformerEngineTensor(ws); NVTE_SCOPED_GIL_RELEASE({ @@ -1143,7 +1143,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } -#ifndef __HIP_PLATFORM_AMD__ +#ifndef USE_ROCM NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->with_rht = quantizer.attr("with_rht").cast(); diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index b36302db2..f937b3de9 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 0ad1e86a4..634c188ce 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 915527736..d8dff33d5 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index 73f926c61..12a87d4bd 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -155,9 +155,13 @@ def run(self): # us to detect CUDA version dynamically during compilation and # choose the correct wheel for te core lib. __version__ = te_version() - cuda_major_version = parse(torch.version.cuda).major - assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}." - te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" + if not rocm_build(): + cuda_major_version = parse(torch.version.cuda).major + assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}." + te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" + install_requires = install_requirements() + [te_core] + else: + te_core = f"transformer_engine_rocm=={__version__}" install_requires = install_requirements() + [te_core] # Configure package diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 8f741b7f2..316733e31 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 86acb7932..12c62437d 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -21,6 +21,8 @@ __all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] +if IS_HIP_EXTENSION: + __all__.extend(["is_mi200", "is_mi308", "is_fp8_fnuz"]) def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" From 0385852c1f825f14410a4cb071e256050f568134 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Mar 2026 14:53:26 -0600 Subject: [PATCH 140/141] _FormatHelperFP8 and missing file add --- .../common/cast/mxfp8/rocm_vectorized_2d.cuh | 81 +++++++++++++++++++ transformer_engine/common/recipe/__init__.py | 33 +++++--- .../jax/csrc/extensions/amax.cpp | 2 +- 3 files changed, 106 insertions(+), 10 deletions(-) create mode 100644 transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh diff --git a/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh new file mode 100644 index 000000000..50474f308 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh @@ -0,0 +1,81 @@ +/************************************************************************* + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * + * License for AMD contributions = MIT. See LICENSE for more information + ************************************************************************/ + +#pragma once + +#include "../../util/vectorized_pointwise.h" + +namespace transformer_engine::rocm { +// These 2d copy functions replace TMA tensormap async copies for AMD GPUs. +template +__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + if (g_row < total_rows) { + const T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedLoaderglobal_loader(current_g_row_base_ptr, total_cols); + + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorershared_storer(current_sh_row_base_ptr, chunk_dim_x); + + global_loader.load(g_col_primitive_start / N_VEC, total_cols); + shared_storer.storage_.scratch_ = global_loader.storage_.scratch_; + shared_storer.store(l_x_vec, chunk_dim_x); + + } else { + T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedStorer shared_storer(current_sh_row_base_ptr, chunk_dim_x); + +#pragma unroll + for (int i = 0; i < N_VEC; ++i) { + shared_storer.separate()[i] = static_cast(0); + } + shared_storer.store(l_x_vec, chunk_dim_x); + } + } +} + +template +__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col, + size_t g_start_row, size_t g_stride, size_t chunk_dim_y, + size_t chunk_dim_x, size_t total_rows, + size_t total_cols) { + const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC; + const size_t l_idx = threadIdx.x; + + for (size_t i_vec = l_idx; i_vec < chunk_dim_y * chunk_dim_x_vec_elements; i_vec += blockDim.x) { + size_t l_y = (i_vec / chunk_dim_x_vec_elements); + size_t l_x_vec = (i_vec % chunk_dim_x_vec_elements); + + size_t g_row = g_start_row + l_y; + size_t g_col_primitive_start = g_start_col + l_x_vec * N_VEC; + + const T* current_sh_row_base_ptr = sh_ptr_base + l_y * chunk_dim_x; + VectorizedLoader shared_loader(current_sh_row_base_ptr, chunk_dim_x); + + T* current_g_row_base_ptr = g_ptr + g_row * g_stride; + VectorizedStorer global_storer(current_g_row_base_ptr, total_cols); + + shared_loader.load(l_x_vec, chunk_dim_x); + + if (g_row < total_rows) { + global_storer.storage_.scratch_ = shared_loader.storage_.scratch_; + global_storer.store(g_col_primitive_start / N_VEC, total_cols); + } + } +} +} // namespace transformer_engine::rocm \ No newline at end of file diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index ae86e492d..223f7a720 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -18,10 +18,30 @@ class _FormatHelper(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. """ - max_fwd: float max_bwd: float +class _FormatHelperFP8(NamedTuple): + """ + Stores max FP8 values for fprop and bprop a `Format`. + """ + fwd: tuple + bwd: tuple + + @property + def max_fwd(self) -> float: + return self.fwd[is_fp8_fnuz()] + + @property + def max_bwd(self) -> float: + return self.bwd[is_fp8_fnuz()] + +class _FormatMaxVals(Enum): + """ + Tuples of FP8 (OCP, FNUZ) values for different formats. + """ + E4M3 = (448, 240) + E5M2 = (57344, 57344) class Format(Enum): """ @@ -40,15 +60,10 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) - if te_rocm_build: - max_e4m3_val = 240 if is_fp8_fnuz() else 448 - E4M3 = _FormatHelper(max_fwd=max_e4m3_val, max_bwd=max_e4m3_val) - else: - E4M3 = _FormatHelper(max_fwd=448, max_bwd=448) - E5M2 = _FormatHelper(max_fwd=57344, max_bwd=57344) - HYBRID = _FormatHelper(max_fwd=E4M3.max_fwd, max_bwd=E5M2.max_bwd) + E4M3 = _FormatHelperFP8(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelperFP8(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) + HYBRID = _FormatHelperFP8(fwd=E4M3.fwd, bwd=E5M2.bwd) @dataclass(frozen=True) diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp index 050d0fd23..aa40a8e35 100644 --- a/transformer_engine/jax/csrc/extensions/amax.cpp +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -101,4 +101,4 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( } // namespace jax } // namespace transformer_engine -#endif // #ifndef USE_ROCM \ No newline at end of file +#endif // #ifndef USE_ROCM From 46d382db16b620e02c06c01e822d413f23ddd898 Mon Sep 17 00:00:00 2001 From: Alex Magro Date: Tue, 3 Mar 2026 14:59:50 -0600 Subject: [PATCH 141/141] add use_async_d2h_group_size as a test parameter --- tests/jax/test_custom_call_compute.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3b9ee0034..9303d6da8 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1786,11 +1786,14 @@ def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("layout", ["NN"]) - def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + @pytest_parametrize_wrapper("use_async_d2h_group_size", [True, False]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout, use_async_d2h_group_size): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) - if not is_hip_extension(): + if use_async_d2h_group_size: + if is_hip_extension(): + pytest.skip("ROCm does not support use_async_d2h_group_sizes yet.") num_gemms = input_shape[0] _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( group_sizes, @@ -1806,7 +1809,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): rhs, group_sizes, contracting_dims, - use_async_d2h_group_sizes=not is_hip_extension(), + use_async_d2h_group_sizes=use_async_d2h_group_size, ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)