diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f40b28189..42c5f0342 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -83,7 +83,9 @@ jobs: options: --user root steps: - name: 'Dependencies' - run: pip install torch pybind11[global] einops onnxscript + run: | + pip install pybind11[global] einops onnxscript + pip install torch --index-url https://download.pytorch.org/whl/cu130 - name: 'Checkout' uses: actions/checkout@v3 with: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bbe486fac..5043d6ea2 100755 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -38,3 +38,9 @@ repos: entry: clang-format -i args: ["-style=file"] files: ^transformer_engine.*\.(c|cc|cxx|cpp|cu|cuh|h|hpp)$ + + - repo: https://github.com/netromdk/vermin + rev: c75aca72f4e85c6e47252139e8695f1c8b5f9ae3 + hooks: + - id: vermin + args: ['-t=3.10', '--violations'] diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index deda80e53..be6c079be 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 +Subproject commit be6c079be8aaffa0fc079fcf039887e637c289c7 diff --git a/README.rst b/README.rst index 66d1b0b3e..adf19aca1 100644 --- a/README.rst +++ b/README.rst @@ -362,6 +362,14 @@ Transformer Engine Latest News =========== + +* [09/2025] `Pretraining Large Language Models with NVFP4 `_ +* [09/2025] `Native FP8 Mixed Precision Training for Ling 2.0, Open Sourced! 
`_ +* [09/2025] `Faster Training Throughput in FP8 Precision with NVIDIA NeMo `_ +* [08/2025] `How we built DeepL's next-generation LLMs with FP8 for training and inference `_ +* [08/2025] `NVFP4 Trains with Precision of 16-bit and Speed and Efficiency of 4-bit `_ +* [06/2025] `Floating Point 8: An Introduction to Efficient, Lower-Precision AI Training `_ +* [05/2025] `Advanced Optimization Strategies for LLM Training on NVIDIA Grace Hopper `_ * [03/2025] `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 `_ * [03/2025] `Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking `_ @@ -436,7 +444,7 @@ PyTorch fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3) # Enable autocasting for the forward pass - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with te.autocast(enabled=True, recipe=fp8_recipe): out = model(inp) loss = out.sum() @@ -471,7 +479,7 @@ Flax fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID) # Enable autocasting for the forward pass - with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with te.autocast(enabled=True, recipe=fp8_recipe): model = te_flax.DenseGeneral(features=HIDDEN) def loss_fn(params, other_vars, inp): @@ -547,7 +555,7 @@ pip Installation **Prerequisites for pip installation:** * A compatible C++ compiler -* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed +* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) if installing from source. To install the latest stable version with pip: diff --git a/benchmarks/benchmark_rht_cast.py b/benchmarks/benchmark_rht_cast.py new file mode 100644 index 000000000..9c47856f7 --- /dev/null +++ b/benchmarks/benchmark_rht_cast.py @@ -0,0 +1,152 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import argparse +import torch +import pandas as pd +import torch.utils.benchmark as benchmark + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +import transformer_engine.pytorch.cpp_extensions as ext + +from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + +scale_padding_to = 1 +permute_scale = False + +TORCH_TO_TE_FLOAT_MAP = { + torch.bfloat16: tex.DType.kBFloat16, +} + + +def run_kernel(shape, stochastic_rounding: bool, input_dtype=torch.bfloat16): + # Generate random input data + M, K = shape + x = torch.randn([M, K], dtype=input_dtype, device="cuda") + + assert shape[0] % 16 == 0, "Shape must be divisible by 16" + assert shape[1] % 16 == 0, "Shape must be divisible by 16" + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=True, + with_post_rht_amax=True, + with_random_sign_mask=True, + stochastic_rounding=stochastic_rounding, + ) + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, K), dtype=x.dtype, device=x.device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + with torch.no_grad(): + stmt = "kernel_func(input, output)" + globals_dict = { + "kernel_func": nvfp4_quantizer.update_quantized, + "input": x, + "output": x_nvfp4_sut, + } + + timing = benchmark.Timer( + stmt=stmt, + globals=globals_dict, + num_threads=1, + ).blocked_autorange(min_run_time=5) + print(timing) + timing_us = timing.median * 1e6 + + input_nbytes = shape[0] * shape[1] * 2 # bf16 + output_nbytes = shape[0] * shape[1] // 2 # //2 for fp4 + sf_nbytes = shape[0] * shape[1] // 16 # //16 for 1 byte per 16 elems + + total_nbytes = ( + 0 + + input_nbytes + * 3 # Reading input for Amax(x)&Amax(RHT(x.T)), Reading input for Cast(x), Reaindg input for Cast(RHT(x.T)) + + 2 * 4 # Output 2 * float for scale & amax + + 2 * 4 # Input 2 * float + + output_nbytes * 2 # 
Output from Cast(x) and Cast(RHT(x.T)) + + sf_nbytes * 2 # Scale factor + ) + + throughput_GBps = total_nbytes / (1024 * 1024 * 1024) / (timing_us / 1e6) + + print( + f"Stochastic rounding: {stochastic_rounding}, Total: {total_nbytes} bytes, Throughput:" + f" {throughput_GBps} GB/s" + ) + return timing_us, throughput_GBps + + +# Nsight Compute Profiling Command: +# ncu -f -o block_scaled_1d_cast_transpose_kernel --set=full --kernel-name "block_scaled_1d_cast_transpose_kernel" -s 5 -c 5 python benchmark_cast_transpose_1d_block.py --profile + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--profile", action="store_true", help="Enable profiling mode") + args = parser.parse_args() + + if args.profile: + print("Profiling is enabled.") + else: + print("Profiling is disabled.") + + shapes = [ + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ] + + if args.profile: + shapes = [ + (16384, 6144), + ] + + data = [] + for stochastic_rounding in [True]: # , False]: + for shape in shapes: + print( + f"Running benchmark_func with shape {shape} and stochastic_rounding" + f" {stochastic_rounding}" + ) + timing_us, throughput_GBps = run_kernel(shape, stochastic_rounding) + data.append( + [ + "benchmark_func", + shape, + stochastic_rounding, + timing_us, + throughput_GBps, + ] + ) + + df = pd.DataFrame( + data=data, + columns=[ + "kernel", + "shape", + "stochastic_rounding", + "timing_us", + "throughput(GB/s)", + ], + ) + print(df) + df.to_csv("benchmark_cast_nvfp4.csv", index=False) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 44f1c8967..d4bbad75c 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -6,58 +6,69 @@ import torch import 
torch.utils.benchmark as benchmark import pandas as pd -import pathlib from transformer_engine.pytorch.module import GroupedLinear -from transformer_engine.common.recipe import Float8BlockScaling, MXFP8BlockScaling -from transformer_engine.pytorch.fp8 import fp8_autocast, FP8GlobalStateManager +from transformer_engine.common.recipe import ( + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) +from transformer_engine.pytorch.quantization import autocast, FP8GlobalStateManager from contextlib import nullcontext """ # Profile BF16 recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_bf16 \ + --output=./benchmarks/linear/b200_numgemm_8_bf16 \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe bf16 # Profile FP8 sub-channel recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/h100hbm_mkn_4096_4096_4096_numgemm_8_fp8_sub_channel \ + --output=./benchmarks/linear/h100hbm_numgemm_8_fp8_sub_channel \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe fp8_sub_channel # Profile MXFP8 recipe with Nsight Systems nsys profile \ - --output=./benchmarks/linear/b200_mkn_4096_4096_4096_numgemm_8_mxfp8 \ + --output=./benchmarks/linear/b200_numgemm_8_mxfp8 \ --force-overwrite true \ --trace=cuda,nvtx,cudnn,cublas \ python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe mxfp8 +# Profile NVFP4 recipe with Nsight Systems +nsys profile \ + --output=./benchmarks/linear/b200_numgemm_8_nvfp4 \ + --force-overwrite true \ + --trace=cuda,nvtx,cudnn,cublas \ + python benchmarks/linear/benchmark_grouped_linear.py --profile --recipe nvfp4 + """ RECIPES = { "bf16": None, "fp8_sub_channel": Float8BlockScaling(), "mxfp8": MXFP8BlockScaling(), + "nvfp4": NVFP4BlockScaling(), } mxfp8_available, reason_for_no_mxfp8 = 
FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( FP8GlobalStateManager.is_fp8_block_scaling_available() ) +nvfp4_available, reason_for_no_nvfp4 = FP8GlobalStateManager.is_nvfp4_available() def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps=1, recipe=None): assert mode in ["fwd_only", "fwd_bwd"] - fp8_context = ( - fp8_autocast(enabled=True, fp8_recipe=recipe) if recipe is not None else nullcontext() + quantization_context = ( + autocast(enabled=True, recipe=recipe) if recipe is not None else nullcontext() ) - # print(f"fp8_context: {fp8_context} and is it nullcontext? {isinstance(fp8_context, nullcontext)}") if mode == "fwd_only": - with torch.no_grad(), fp8_context: + with torch.no_grad(), quantization_context: for i in range(run_num_steps): y_q = layer.forward( x, @@ -70,7 +81,7 @@ def run_linear_multiple_steps(layer, x, m_splits, mode, gradient, run_num_steps= layer.zero_grad() x.grad = None - with fp8_context: + with quantization_context: for i in range(run_num_steps): label = f"step_{i}" torch.cuda.nvtx.range_push(label) @@ -145,7 +156,7 @@ def benchmark_linear( "recipe": recipe, }, num_threads=1, - ).blocked_autorange(min_run_time=5) + ).blocked_autorange(min_run_time=10) print(f"{recipe_name}: {timing} \n") timing_ms = timing.median * 1000 / num_microbatches @@ -228,30 +239,44 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): use_bias = False # Set the MKN values to benchmark + # Deepseek V3 EP64, SEQ_LEN=8192, topK8 + # 256 expert => 4 local experts + # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 16384 + # M = AvgM * localExperts = 65536 + # K = 7168 + # N = 2048 + + # Deepseek V3 EP32, SEQ_LEN=8192, topK8 + # 256 expert => 8 local experts + # Avg M per expert: AvgM = SEQ_LEN * topK / localExperts = 8192 + # M = AvgM * localExperts = 65536 + # K = 7168 + # N = 2048 + + # 4 or 8local experts per rank + num_gemms_list = [4, 8] + + # MKN 
for group linear mkns = [] - for m in [8192]: - # for m in [4096, 8192, 16384]: - # for n in [1024, 2048, 4096, 8192, 16384]: - for n in [8192]: - for k in [4096]: + for m in [65536]: + for k in [7168]: + for n in [2048]: mkns.append((m, k, n)) # default recipes to run if not specified recipe_list = ["bf16"] if args.recipe == "all": - recipe_list = ["bf16", "fp8_sub_channel", "mxfp8"] + recipe_list = ["bf16", "fp8_sub_channel", "mxfp8", "nvfp4"] else: recipe_list = [args.recipe] - num_gemms_list = [8] - if args.profile: - mkns = [(4096 * 8, 4096, 4096)] + mkns = [(8192 * 8, 7168, 2048)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" - " fp8_sub_channel, mxfp8, or bf16" + " fp8_sub_channel, mxfp8, nvfp4, or bf16" ) recipe_list = [args.recipe] num_gemms_list = [8] @@ -268,13 +293,17 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): "bf16", "fp8_sub_channel", "mxfp8", - ], "Recipe must be one of bf16, fp8_sub_channel, or mxfp8" + "nvfp4", + ], "Recipe must be one of bf16, fp8_sub_channel, mxfp8, or nvfp4" if recipe_name == "mxfp8" and not mxfp8_available: print(f"MXFP8 is not available, skipping {recipe_name}") continue if recipe_name == "fp8_sub_channel" and not fp8_block_scaling_available: print(f"FP8 block scaling is not available, skipping {recipe_name}") continue + if recipe_name == "nvfp4" and not nvfp4_available: + print(f"NVFP4 is not available, skipping {recipe_name}") + continue df = run_benchmark_linear( mkns, diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 81006d78c..c7f2fd9b8 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.8.0.dev0 +2.10.0.dev0 diff --git a/build_tools/build_ext.py b/build_tools/build_ext.py index 8bcfc5a69..1e6651b2f 100644 --- a/build_tools/build_ext.py +++ b/build_tools/build_ext.py @@ -61,6 +61,7 @@ def _build_cmake(self, build_dir: 
Path, install_dir: Path) -> None: build_dir, f"-DPython_EXECUTABLE={sys.executable}", f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}", + f"-DPython_SITEARCH={sysconfig.get_path('platlib')}", f"-DCMAKE_BUILD_TYPE={build_type}", f"-DCMAKE_INSTALL_PREFIX={install_dir}", ] diff --git a/build_tools/jax.py b/build_tools/jax.py index 182940c11..6f2e57f87 100644 --- a/build_tools/jax.py +++ b/build_tools/jax.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -105,6 +105,7 @@ def setup_jax_extension( sources=[str(path) for path in sources], include_dirs=[str(path) for path in include_dirs], extra_compile_args=cxx_flags, + libraries=["nccl"] if not rocm_build() else [], ) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index bb084293f..4bbefd730 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -27,12 +27,12 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] + return ["torch>=2.1", "einops", "onnxscript", "onnx"] def test_requirements() -> List[str]: """Test dependencies for TE/JAX extensions.""" - return ["numpy", "torchvision", "transformers"] + return ["numpy", "torchvision", "transformers", "torchao==0.13"] def setup_pytorch_extension( diff --git a/build_tools/utils.py b/build_tools/utils.py index e3c5b6be8..0c18d7ecf 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -15,12 +15,31 @@ import shutil import subprocess import sys +import platform from pathlib import Path from importlib.metadata import version as get_version from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union +# Needs to stay consistent with .pre-commit-config.yaml config. +def min_python_version() -> Tuple[int]: + """Minimum supported Python version.""" + return (3, 10, 0) + + +def min_python_version_str() -> str: + """String representing minimum supported Python version.""" + return ".".join(map(str, min_python_version())) + + +if sys.version_info < min_python_version(): + raise RuntimeError( + f"Transformer Engine requires Python {min_python_version_str()} or newer, " + f"but found Python {platform.python_version()}." 
+ ) + + @functools.lru_cache(maxsize=None) def debug_build_enabled() -> bool: """Whether to build with a debug configuration""" @@ -305,15 +324,16 @@ def get_cuda_include_dirs() -> Tuple[str, str]: @functools.lru_cache(maxsize=None) def cuda_archs() -> str: - version = cuda_version() - if os.getenv("NVTE_CUDA_ARCHS") is None: + archs = os.getenv("NVTE_CUDA_ARCHS") + if archs is None: + version = cuda_version() if version >= (13, 0): - os.environ["NVTE_CUDA_ARCHS"] = "75;80;89;90;100;120" + archs = "75;80;89;90;100;120" elif version >= (12, 8): - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90;100;120" + archs = "70;80;89;90;100;120" else: - os.environ["NVTE_CUDA_ARCHS"] = "70;80;89;90" - return os.getenv("NVTE_CUDA_ARCHS") + archs = "70;80;89;90" + return archs def cuda_version() -> Tuple[int, ...]: diff --git a/build_tools/wheel_utils/Dockerfile.aarch b/build_tools/wheel_utils/Dockerfile.aarch index 223c4a7f1..404cb941c 100644 --- a/build_tools/wheel_utils/Dockerfile.aarch +++ b/build_tools/wheel_utils/Dockerfile.aarch @@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_aarch64 WORKDIR /TransformerEngine/ COPY ../.. /TransformerEngine/ -ARG VER="12-3" -ARG ARCH="aarch64" -RUN dnf -y install vim +ARG CUDA_MAJOR="12" +ARG CUDA_MINOR="3" + +# Args for build_wheels.sh +ARG BUILD_METAPACKAGE=true +ARG BUILD_COMMON=true +ARG BUILD_PYTORCH=true +ARG BUILD_JAX=true +ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE} +ENV BUILD_COMMON=${BUILD_COMMON} +ENV BUILD_PYTORCH=${BUILD_PYTORCH} +ENV BUILD_JAX=${BUILD_JAX} +ENV CUDA_MAJOR=${CUDA_MAJOR} # Cuda toolkit, cudnn, driver. 
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo RUN dnf -y install epel-release -RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ - cuda-libraries-${VER}.${ARCH} \ - cuda-libraries-devel-${VER}.${ARCH} -RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \ + cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \ + cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 +RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR} RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit +RUN dnf -y install cuda-toolkit-${CUDA_MAJOR} RUN dnf clean all RUN dnf -y install glog.aarch64 glog-devel.aarch64 +RUN dnf -y install libnccl libnccl-devel libnccl-static ENV PATH="/usr/local/cuda/bin:${PATH}" ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" @@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"] +CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_aarch64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"] diff --git a/build_tools/wheel_utils/Dockerfile.x86 b/build_tools/wheel_utils/Dockerfile.x86 index 26122eed9..daa7f961c 100644 --- a/build_tools/wheel_utils/Dockerfile.x86 +++ b/build_tools/wheel_utils/Dockerfile.x86 @@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_x86_64 WORKDIR /TransformerEngine/ COPY ../.. 
/TransformerEngine/ -ARG VER="12-3" -ARG ARCH="x86_64" -RUN dnf -y install vim +ARG CUDA_MAJOR="12" +ARG CUDA_MINOR="3" + +# Args for build_wheels.sh +ARG BUILD_METAPACKAGE=true +ARG BUILD_COMMON=true +ARG BUILD_PYTORCH=true +ARG BUILD_JAX=true +ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE} +ENV BUILD_COMMON=${BUILD_COMMON} +ENV BUILD_PYTORCH=${BUILD_PYTORCH} +ENV BUILD_JAX=${BUILD_JAX} +ENV CUDA_MAJOR=${CUDA_MAJOR} # Cuda toolkit, cudnn, driver. RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo RUN dnf -y install epel-release -RUN dnf -y install cuda-compiler-${VER}.${ARCH} \ - cuda-libraries-${VER}.${ARCH} \ - cuda-libraries-devel-${VER}.${ARCH} -RUN dnf -y install --allowerasing cudnn9-cuda-12 +RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \ + cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \ + cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 +RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR} RUN dnf clean all RUN rm -rf /var/cache/dnf/* RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf -RUN dnf -y install cuda-toolkit +RUN dnf -y install cuda-toolkit-${CUDA_MAJOR} RUN dnf clean all RUN dnf -y install glog.x86_64 glog-devel.x86_64 +RUN dnf -y install libnccl libnccl-devel libnccl-static ENV PATH="/usr/local/cuda/bin:${PATH}" ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}" @@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda ENV CUDADIR=/usr/local/cuda ENV NVTE_RELEASE_BUILD=1 -CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"] +CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_x86_64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"] \ No newline at end of file diff --git a/build_tools/wheel_utils/build_wheels.sh b/build_tools/wheel_utils/build_wheels.sh 
index 4a6653479..0e8ab68a8 100644 --- a/build_tools/wheel_utils/build_wheels.sh +++ b/build_tools/wheel_utils/build_wheels.sh @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,8 +11,10 @@ BUILD_METAPACKAGE=${2:-true} BUILD_COMMON=${3:-true} BUILD_PYTORCH=${4:-true} BUILD_JAX=${5:-true} +CUDA_MAJOR=${6:-12} export NVTE_RELEASE_BUILD=1 +export PIP_CONSTRAINT="" export TARGET_BRANCH=${TARGET_BRANCH:-} mkdir -p /wheelhouse/logs @@ -41,14 +43,9 @@ else fi if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install setuptools wheel -fi - -# Install deps -if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install pybind11[global] ninja + ${PYBINDIR}pip install pybind11[global] ninja setuptools wheel else - ${PYBINDIR}pip install cmake pybind11[global] ninja + /opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel fi if $BUILD_METAPACKAGE ; then @@ -77,13 +74,23 @@ if $BUILD_COMMON ; then # Create the wheel. ${PYBINDIR}python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt - # Repack the wheel for cuda specific package, i.e. cu12. - ${PYBINDIR}wheel unpack dist/* - # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). 
- sed -i "s/Name: transformer-engine/Name: transformer-engine-${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - sed -i "s/Name: transformer_engine/Name: transformer_engine_${TE_CUDA_VERS}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" - mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_${TE_CUDA_VERS}-${VERSION}.dist-info" - ${PYBINDIR}wheel pack ${WHL_BASE} + if [ "$ROCM_BUILD" = "1" ]; then + # Repack the wheel for specific rocm package. + ${PYBINDIR}wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-rocm/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_rocm/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_rocm-${VERSION}.dist-info" + ${PYBINDIR}wheel pack ${WHL_BASE} + else + # Repack the wheel for specific cuda version. + /opt/python/cp310-cp310/bin/wheel unpack dist/* + # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore). + sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" + mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info" + /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE} + fi # Rename the wheel to make it python version agnostic. 
whl_name=$(basename dist/*) @@ -94,25 +101,25 @@ if $BUILD_COMMON ; then fi if $BUILD_PYTORCH ; then - cd /TransformerEngine/transformer_engine/pytorch - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 - else - PYBINDIR=/opt/python/cp38-cp38/bin/ - ${PYBINDIR}pip install torch - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt - cp dist/* /wheelhouse/ + cd /TransformerEngine/transformer_engine/pytorch + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install torch --index-url https://download.pytorch.org/whl/rocm6.3 + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install torch + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt + cp dist/* /wheelhouse/ fi if $BUILD_JAX ; then - cd /TransformerEngine/transformer_engine/jax - if [ "$ROCM_BUILD" = "1" ]; then - ${PYBINDIR}pip install jax - else - PYBINDIR=/opt/python/cp310-cp310/bin/ - ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib - fi - ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt - cp dist/* /wheelhouse/ + cd /TransformerEngine/transformer_engine/jax + if [ "$ROCM_BUILD" = "1" ]; then + ${PYBINDIR}pip install jax + else + PYBINDIR=/opt/python/cp310-cp310/bin/ + ${PYBINDIR}pip install "jax[cuda12_local]" jaxlib + fi + ${PYBINDIR}python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt + cp dist/* /wheelhouse/ fi diff --git a/build_tools/wheel_utils/launch_aarch.sh b/build_tools/wheel_utils/launch_aarch.sh index 04e3cd691..85f754ca1 100644 --- a/build_tools/wheel_utils/launch_aarch.sh +++ b/build_tools/wheel_utils/launch_aarch.sh @@ -2,7 +2,29 @@ # # See LICENSE for license information. -docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . +# Remove leftovers. +rm -rf aarch_wheelhouse_cu12 aarch_wheelhouse_cu13 + +# CUDA 12. 
+docker build --no-cache \ + --build-arg CUDA_MAJOR=12 \ + --build-arg CUDA_MINOR=3 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . +docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" +docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu12 + +# CUDA 13. +docker build --no-cache \ + --build-arg CUDA_MAJOR=13 \ + --build-arg CUDA_MINOR=0 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch . docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel" -rm -rf aarch_wheelhouse -docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse +docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu13 diff --git a/build_tools/wheel_utils/launch_x86.sh b/build_tools/wheel_utils/launch_x86.sh index b0d20be3f..11fc52294 100644 --- a/build_tools/wheel_utils/launch_x86.sh +++ b/build_tools/wheel_utils/launch_x86.sh @@ -2,7 +2,29 @@ # # See LICENSE for license information. -docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . +# Remove leftovers. +rm -rf x86_wheelhouse_cu12 x86_wheelhouse_cu13 + +# CUDA 12. +docker build --no-cache \ + --build-arg CUDA_MAJOR=12 \ + --build-arg CUDA_MINOR=3 \ + --build-arg BUILD_METAPACKAGE=true \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=true \ + --build-arg BUILD_JAX=true \ + -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . +docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" +docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu12 + +# CUDA 13. 
+docker build --no-cache \ + --build-arg CUDA_MAJOR=13 \ + --build-arg CUDA_MINOR=0 \ + --build-arg BUILD_METAPACKAGE=false \ + --build-arg BUILD_COMMON=true \ + --build-arg BUILD_PYTORCH=false \ + --build-arg BUILD_JAX=false \ + -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 . docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel" -rm -rf x86_wheelhouse -docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse +docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu13 diff --git a/ci/jax.sh b/ci/jax.sh index 81d994585..ef9dbe124 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -59,8 +59,7 @@ run_test_config() { run_default_fa 1 test_functions.py run 1 test_fused_attn.py NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass - run_default_fa 1 test_helper.py - run_default_fa 1 test_layer.py #it effectevly always uses unfused attention + run_default_fa 1 test_layer.py # it effectively always uses unfused attention run_default_fa 1 test_sanity_import.py run_default_fa 1 test_softmax.py } @@ -71,7 +70,7 @@ run_test_config_mgpu() { # Mitigate distributed tests hang by adding 5min timeout _timeout_args="--timeout 300 --timeout-method thread" - # Workaround for some distributed tests hang/abotrion + # Workaround for some distributed tests hang/abortion export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then @@ -81,8 +80,10 @@ run_test_config_mgpu() { _dfa_level=3 export NVTE_JAX_UNITTEST_LEVEL=L2 fi + + run_default_fa 2 test_distributed_dense.py # Do not fail automated CI if test_distributed_fused_attn is hung - # If the sctipt run w/o TEST_LEVEL the test error will be honored + # If the script runs w/o TEST_LEVEL the test error will be honored if [ "$TEST_LEVEL" -le 3 ]; then TEST_ERROR_IGNORE="1" fi @@ -95,7 +96,7 @@ run_test_config_mgpu() { run_default_fa 3 test_sanity_import.py } -# Single config mode, run it 
synchroniously and return result +# Single config mode, run it synchronously and return result if [ -n "$SINGLE_CONFIG" ]; then _fus_attn="$SINGLE_CONFIG" configure_fused_attn_env $_fus_attn && run_test_config diff --git a/ci/pytorch.sh b/ci/pytorch.sh index be150485f..5558beca3 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -51,7 +51,8 @@ run_test_config(){ run_default_fa 1 test_deferred_init.py run_default_fa 1 test_float8tensor.py run_default_fa 1 test_float8_current_scaling_exact.py - test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 run 1 test_cpu_offloading.py + run 1 test_cpu_offloading.py + test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 run 3 test_cpu_offloading_v1.py run_default_fa 1 test_fused_rope.py run_default_fa 1 test_fused_router.py run_default_fa 1 test_fusible_ops.py diff --git a/docs/api/common.rst b/docs/api/common.rst index 541118985..3edd7cae2 100644 --- a/docs/api/common.rst +++ b/docs/api/common.rst @@ -12,6 +12,8 @@ Common API .. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3) +.. autoapiclass:: transformer_engine.common.recipe.NVFP4BlockScaling(fp4_format=Format.E2M1) + .. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID) .. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3) diff --git a/docs/api/jax.rst b/docs/api/jax.rst index 1af5cd1d0..789b27e59 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -30,6 +30,7 @@ Modules .. autoapifunction:: transformer_engine.jax.fp8_autocast +.. autoapifunction:: transformer_engine.jax.autocast .. autoapifunction:: transformer_engine.jax.update_collections diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index fcfa20cbd..fe726d851 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -41,8 +41,28 @@ pyTorch .. 
autoapifunction:: transformer_engine.pytorch.fp8_model_init +.. autoapifunction:: transformer_engine.pytorch.autocast + +.. autoapifunction:: transformer_engine.pytorch.quantized_model_init + .. autoapifunction:: transformer_engine.pytorch.checkpoint +.. autoapifunction:: transformer_engine.pytorch.is_fp8_available + +.. autoapifunction:: transformer_engine.pytorch.is_mxfp8_available + +.. autoapifunction:: transformer_engine.pytorch.is_fp8_block_scaling_available + +.. autoapifunction:: transformer_engine.pytorch.is_nvfp4_available + +.. autoapifunction:: transformer_engine.pytorch.is_bf16_available + +.. autoapifunction:: transformer_engine.pytorch.get_cudnn_version + +.. autoapifunction:: transformer_engine.pytorch.get_device_compute_capability + +.. autoapifunction:: transformer_engine.pytorch.get_default_recipe + .. autoapifunction:: transformer_engine.pytorch.make_graphed_callables .. autoapifunction:: transformer_engine.pytorch.get_cpu_offload_context @@ -65,4 +85,4 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index .. 
autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode - :members: FP8, NONE \ No newline at end of file + :members: FP8, NONE diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index a7b86dad3..906c62556 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -69,7 +69,7 @@ Let's look at a simple example of training a Transformer layer using Transformer for epoch in range(5): transformer_layer.train() optimizer.zero_grad() - with te.fp8_autocast(enabled=True): + with te.autocast(enabled=True): output = transformer_layer(dummy_input) loss = criterion(output, dummy_target) loss.backward() diff --git a/docs/examples/FP4_format.png b/docs/examples/FP4_format.png new file mode 100644 index 000000000..8c54c3379 Binary files /dev/null and b/docs/examples/FP4_format.png differ diff --git a/docs/examples/FP4_linear.png b/docs/examples/FP4_linear.png new file mode 100644 index 000000000..2cd4511ad Binary files /dev/null and b/docs/examples/FP4_linear.png differ diff --git a/docs/examples/advanced_optimizations.ipynb b/docs/examples/advanced_optimizations.ipynb index 3d889859b..5dc9cb92f 100644 --- a/docs/examples/advanced_optimizations.ipynb +++ b/docs/examples/advanced_optimizations.ipynb @@ -71,7 +71,7 @@ " amax_compute_algo=\"max\",\n", ")\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = basic_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "\n", @@ -81,7 +81,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] }, @@ -135,7 +135,7 @@ "\n", "Enabling data parallelism with Transformer Engine is similar to enabling data parallelism with standard PyTorch models: simply wrap the modules with 
[torch.nn.parallel.DistributedDataParallel](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html). Transformer Engine modules also have native support for tensor and sequence parallelism. If the user provides a process group for tensor parallelism, the modules will distribute the data and perform communication internally. If sequence parallelism is enabled, it will be applied for operations that are not amenable to tensor parallelism and it will use the tensor-parallel process group.\n", "\n", - "One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager." + "One important consideration for multi-GPU FP8 training is how to synchronize the FP8 scaling factors between GPUs. If tensor parallelism is enabled, the scales must be synchronized over the tensor-parallel group. However, synchronizing over both the data-parallel and tensor-parallel groups is recommended for the best convergence. This can be configured with the **fp8_group** argument in the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager." 
] }, { @@ -169,7 +169,7 @@ ")\n", "\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=world_group):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=world_group):\n", " y = parallel_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "\n", @@ -179,10 +179,10 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = {\n", + " autocast_kwargs = {\n", " \"enabled\": True,\n", - " \"fp8_recipe\": fp8_recipe,\n", - " \"fp8_group\": world_group,\n", + " \"recipe\": fp8_recipe,\n", + " \"amax_reduction_group\": world_group,\n", " },\n", ")" ] @@ -234,7 +234,7 @@ " param.main_grad = torch.zeros_like(param, dtype=torch.float32)\n", "\n", "# Training step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = wgrad_transformer(x, attention_mask=None)\n", "y.backward(dy)\n", "for param in wgrad_transformer.parameters():\n", @@ -248,7 +248,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] }, @@ -268,7 +268,7 @@ "\n", "\n", "\n", - "Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. 
Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n", + "Since weights are typically trained in FP32, a type conversion is required before we can perform compute in FP8. By default, the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager will handle this internally by casting non-FP8 tensors to FP8 as they are encountered. However, we can improve upon this in some cases. In particular, if our training iteration is split into multiple gradient accumulation steps, each micro-batch will encounter the same weight tensors. Thus, we only need to cast the weights to FP8 in the first gradient accumulation step and we can cache the resulting FP8 weights for the remaining gradient accumulation steps.\n", "\n", "
\n", "\n", @@ -303,12 +303,12 @@ "weight_caching_transformer.to(dtype=dtype).cuda()\n", "\n", "# Cast weights in first gradient accumulation step\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=True)\n", "y.backward(dy)\n", "\n", "# Reuse FP8 weights in subsequent gradient accumulation steps\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = weight_caching_transformer(x, attention_mask=None, is_first_microbatch=False)\n", "y.backward(dy)\n", "\n", @@ -318,7 +318,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None, \"is_first_microbatch\": False },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] } diff --git a/docs/examples/fp8_primer.ipynb b/docs/examples/fp8_primer.ipynb index 788d6c37a..457d13921 100644 --- a/docs/examples/fp8_primer.ipynb +++ b/docs/examples/fp8_primer.ipynb @@ -5,9 +5,9 @@ "id": "7b3e6954", "metadata": {}, "source": [ - "# Using FP8 with Transformer Engine\n", + "# Using FP8 and FP4 with Transformer Engine\n", "\n", - "H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. In this example we will introduce the FP8 datatype and show how to use it with Transformer Engine.\n", + "H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. Blackwell added support for NVFP4 and MXFP8 datatypes. 
In this example we will introduce these low precision datatypes and show how to use them with Transformer Engine.\n", "\n", "## Introduction to FP8\n", "\n", @@ -100,19 +100,66 @@ "" ] }, + { + "cell_type": "markdown", + "id": "fd7b4f37-50a2-4d41-9067-cf0c471cb2d7", + "metadata": {}, + "source": [ + "## Beyond FP8 - training with NVFP4\n", + "\n", + "In addition to MXFP8, NVIDIA Blackwell introduced support for an even smaller, 4-bit format called NVFP4. The values are represented there in E2M1 format, able to represent values of magnitude up to +/-6.\n", + "\n", + "
\n", + "\n", + "
Figure 8: FP4 E2M1 format can represent values between +/-6.
\n", + "
\n", + "\n", + "### NVFP4 Format\n", + "\n", + "NVFP4 format is similar to MXFP8 - it also uses granular scaling to preserve the dynamic range. The differences are:\n", + "\n", + " - Granularity of the scaling factors: in NVFP4 format a single scaling factor is used per block of 16 elements, whereas MXFP8 uses 1 scaling factor per block of 32 elements\n", + " - Datatype of the scaling factors: NVFP4 uses FP8 E4M3 as the scaling factor per block, whereas MXFP8 uses E8M0 as the scaling factor datatype. Choice of E4M3 for the scaling factor enables preservation of more information about mantissa, but does not enable the full dynamic range of FP32. Therefore, NVFP4 uses an additional single per-tensor FP32 scaling factor to avoid overflows.\n", + "\n", + "In the NVFP4 training recipe for weight tensors we use a different variant of the NVFP4 quantization, where a single scaling factor is shared by a 2D block of 16x16 elements. This is similar to the weight quantization scheme employed in [DeepSeek-v3 training](https://arxiv.org/abs/2412.19437v1), but with a much finer granularity.\n", + "\n", + "### NVFP4 training recipe\n", + "\n", + "The NVFP4 training recipe implemented in Transformer Engine is described in [Pretraining Large Language Models with NVFP4](https://arxiv.org/abs/2509.25149v1) paper. The main elements of the recipe are:\n", + "\n", + " - Stochastic Rounding. When quantizing gradients to NVFP4, we use stochastic rounding to avoid the bias introduced by quantization. With stochastic rounding values are rounded probabilistically to one of their two nearest representable numbers, with probabilities inversely\n", + "proportional to their distances.\n", + " - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. 
While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n", + " - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n", + "disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n", + " - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization and so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have the full information about the structure of the network being trained. This can be easily achieved though by modifying the model training code to run the last few layers under a different `autocast` (or nesting 2 autocasts in order to override the recipe for a part of the network).\n", + "\n", + "The full linear layer utilizing NVFP4 is presented in Figure 9.\n", + "\n", + "
\n", + "\n", + "
Figure 9: Linear layer utilizing NVFP4
\n", + "
" + ] + }, { "cell_type": "markdown", "id": "cf5e0b0d", "metadata": {}, "source": [ - "## Using FP8 with Transformer Engine\n", + "## Using FP8 and FP4 with Transformer Engine\n", "\n", - "Transformer Engine library provides tools enabling easy to use training with FP8 datatype using FP8 delayed scaling and MXFP8 strategies.\n", + "Transformer Engine library provides tools enabling easy to use training with FP8 and FP4 datatypes using different strategies.\n", "\n", "### FP8 recipe\n", "\n", - "The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe from the `transformer_engine.common.recipe` module stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", - "Similarly, [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) from the same module may be used to enable MXFP8 training." + "Transformer Engine defines a range of different low precision recipes to choose from in the `transformer_engine.common.recipe` module.\n", + "\n", + " - The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n", + " - [Float8CurrentScaling](../api/common.rst#transformer_engine.common.recipe.Float8CurrentScaling) recipe enables current per-tensor scaling with FP8.\n", + " - [Float8BlockScaling](../api/common.rst#transformer_engine.common.recipe.Float8BlockScaling) recipe enables block scaling with FP8 as described in [DeepSeek-v3 paper](https://arxiv.org/abs/2412.19437v1).\n", + " - [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) recipe enables MXFP8 training.\n", + " - [NVFP4BlockScaling](../api/common.rst#transformer_engine.common.recipe.NVFP4BlockScaling) recipe 
enables NVFP4 training." ] }, { @@ -122,12 +169,13 @@ "metadata": {}, "outputs": [], "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling\n", + "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling\n", "\n", "fp8_format = Format.HYBRID # E4M3 during forward pass, E5M2 during backward pass\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "mxfp8_format = Format.E4M3 # E4M3 used everywhere\n", - "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)" + "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)\n", + "nvfp4_recipe = NVFP4BlockScaling()" ] }, { @@ -135,7 +183,7 @@ "id": "f9591eb5", "metadata": {}, "source": [ - "This recipe is then used to configure the FP8 training." + "This recipe is then used to configure the low precision training." ] }, { @@ -145,7 +193,7 @@ "source": [ "### FP8 autocasting\n", "\n", - "Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager." + "Not every operation is safe to be performed using FP8. All of the modules provided by Transformer Engine library were designed to provide maximum performance benefit from FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager." 
] }, { @@ -164,7 +212,7 @@ "\n", "inp = torch.rand((1024, 768)).cuda()\n", "\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " out_fp8 = my_linear(inp)" ] }, @@ -173,7 +221,7 @@ "id": "e41161f1", "metadata": {}, "source": [ - "The `fp8_autocast` context manager hides the complexity of handling FP8:\n", + "The `autocast` context manager hides the complexity of handling FP8:\n", "\n", "- All FP8-safe operations have their inputs cast to FP8\n", "- Amax history is updated\n", @@ -195,9 +243,9 @@ "source": [ "### Handling backward pass\n", "\n", - "When a model is run inside the `fp8_autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `fp8_autocast` context manager aggregates the tensors before performing the communication.\n", + "When a model is run inside the `autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. In order to perform that communication without introducing much overhead, `autocast` context manager aggregates the tensors before performing the communication.\n", "\n", - "Due to this aggregation the backward call needs to happen outside of the `fp8_autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass." + "Due to this aggregation the backward call needs to happen outside of the `autocast` context manager. It has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass." 
] }, { @@ -209,11 +257,11 @@ "source": [ "loss_fp8 = out_fp8.mean()\n", "\n", - "loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast\n", + "loss_fp8.backward() # This backward pass uses FP8, since out_fp8 was calculated inside autocast\n", "\n", "out_fp32 = my_linear(inp)\n", "loss_fp32 = out_fp32.mean()\n", - "loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside fp8_autocast" + "loss_fp32.backward() # This backward pass does not use FP8, since out_fp32 was calculated outside autocast" ] }, { @@ -235,13 +283,13 @@ { "data": { "text/plain": [ - "tensor([[ 0.2276, 0.2627, 0.3001, ..., 0.0346, 0.2211, 0.1188],\n", - " [-0.0963, -0.3725, 0.1717, ..., 0.0901, 0.0522, -0.3472],\n", - " [ 0.4526, 0.3482, 0.5976, ..., -0.0687, -0.0382, 0.1566],\n", + "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n", + " [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n", + " [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n", " ...,\n", - " [ 0.1698, 0.6061, 0.0385, ..., -0.2875, -0.1152, -0.0260],\n", - " [ 0.0679, 0.2946, 0.2751, ..., -0.2284, 0.0517, -0.1441],\n", - " [ 0.1865, 0.2353, 0.9172, ..., 0.1085, 0.1135, 0.1438]],\n", + " [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n", + " [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n", + " [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)" ] }, @@ -263,13 +311,13 @@ { "data": { "text/plain": [ - "tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.0233, 0.2498, 0.1131],\n", - " [-0.0767, -0.3778, 0.1862, ..., 0.0858, 0.0676, -0.3369],\n", - " [ 0.4615, 0.3593, 0.5813, ..., -0.0779, -0.0349, 0.1422],\n", + "tensor([[ 0.2373, 0.2674, 0.2980, ..., 0.1134, -0.3661, 0.1650],\n", + " [-0.0767, -0.3778, 0.1862, ..., -0.1370, -0.8448, -0.1770],\n", + " [ 0.4615, 0.3593, 0.5813, ..., 0.1696, -0.8826, -0.1826],\n", " ...,\n", - " [ 0.1914, 0.6038, 
0.0382, ..., -0.2847, -0.0991, -0.0423],\n", - " [ 0.0864, 0.2895, 0.2719, ..., -0.2388, 0.0772, -0.1541],\n", - " [ 0.2019, 0.2275, 0.9027, ..., 0.1022, 0.1300, 0.1444]],\n", + " [ 0.1914, 0.6038, 0.0382, ..., 0.4049, -0.4729, 0.0118],\n", + " [ 0.0864, 0.2895, 0.2719, ..., -0.3337, -0.4922, 0.1240],\n", + " [ 0.2019, 0.2275, 0.9027, ..., 0.0706, -0.5481, 0.1356]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)" ] }, @@ -300,13 +348,13 @@ "name": "stdout", "output_type": "stream", "text": [ - "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.0346, 0.2211, 0.1188],\n", - " [-0.0963, -0.3724, 0.1717, ..., 0.0901, 0.0522, -0.3470],\n", - " [ 0.4526, 0.3479, 0.5976, ..., -0.0686, -0.0382, 0.1566],\n", + "tensor([[ 0.2276, 0.2629, 0.3000, ..., 0.1297, -0.3702, 0.1807],\n", + " [-0.0963, -0.3724, 0.1717, ..., -0.1250, -0.8501, -0.1669],\n", + " [ 0.4526, 0.3479, 0.5976, ..., 0.1685, -0.8864, -0.1977],\n", " ...,\n", - " [ 0.1698, 0.6062, 0.0385, ..., -0.2876, -0.1152, -0.0260],\n", - " [ 0.0679, 0.2947, 0.2750, ..., -0.2284, 0.0516, -0.1441],\n", - " [ 0.1865, 0.2353, 0.9170, ..., 0.1085, 0.1135, 0.1438]],\n", + " [ 0.1698, 0.6062, 0.0385, ..., 0.4038, -0.4564, 0.0143],\n", + " [ 0.0679, 0.2947, 0.2750, ..., -0.3271, -0.4990, 0.1198],\n", + " [ 0.1865, 0.2353, 0.9170, ..., 0.0673, -0.5567, 0.1246]],\n", " device='cuda:0', grad_fn=<_LinearBackward>)\n" ] } @@ -339,19 +387,14 @@ { "data": { "text/plain": [ - "tensor([[ 4.9591e-05, -1.9073e-04, 9.5367e-05, ..., -3.8147e-06,\n", - " 4.1962e-05, 2.2888e-05],\n", - " [ 2.2888e-05, -3.4332e-05, 2.2888e-05, ..., 2.6703e-05,\n", - " 5.3406e-05, -1.4114e-04],\n", - " [-3.8147e-05, 2.6703e-04, -3.8147e-06, ..., -5.7220e-05,\n", - " 4.1962e-05, -1.9073e-05],\n", + "tensor([[0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", " ...,\n", - " [ 1.1444e-05, -7.2479e-05, -3.8147e-06, ..., 5.3406e-05,\n", - " -1.5259e-05, 2.2888e-05],\n", - " [ 4.9591e-05, -9.5367e-05, 6.8665e-05, 
..., -1.5259e-05,\n", - " 7.6294e-05, 4.5776e-05],\n", - " [-1.5259e-05, -7.6294e-06, 1.8692e-04, ..., -3.0518e-05,\n", - " -4.5776e-05, 7.6294e-06]], device='cuda:0', grad_fn=)" + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.],\n", + " [0., 0., 0., ..., 0., 0., 0.]], device='cuda:0',\n", + " grad_fn=)" ] }, "execution_count": 7, @@ -370,6 +413,53 @@ "source": [ "The differences in result coming from FP8 execution do not matter during the training process, but it is good to understand them, e.g. during debugging the model." ] + }, + { + "cell_type": "markdown", + "id": "d45e8b6c-803b-4a4f-8835-c19b0a94bc6a", + "metadata": {}, + "source": [ + "### Using multiple recipes in the same training run\n", + "\n", + "Sometimes it is desirable to use multiple recipes in the same training run. An example of this is the NVFP4 training, where a few layers at the end of the model should be run in higher precision. This can be achieved by using multiple autocasts, either completely separately or in a nested way (this could be useful when e.g. we want to have a configurable overarching recipe but still hardcode a different recipe for some pieces of the network)."
+ ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c663f694-41d6-47c0-a397-5fc56e692542", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[ 0.0547, 0.0039, -0.0664, ..., -0.2061, 0.2344, -0.3223],\n", + " [ 0.0131, -0.1436, 0.0168, ..., -0.4258, 0.1562, -0.0371],\n", + " [ 0.1074, -0.2773, 0.0576, ..., -0.2070, 0.0640, -0.1611],\n", + " ...,\n", + " [ 0.0825, -0.0630, 0.0571, ..., -0.3711, 0.1562, -0.4062],\n", + " [-0.1729, -0.1138, -0.0620, ..., -0.4238, 0.0703, -0.2070],\n", + " [-0.0908, -0.2148, 0.2676, ..., -0.4551, 0.1836, -0.4551]],\n", + " device='cuda:0', dtype=torch.bfloat16, grad_fn=<_LinearBackward>)\n" + ] + } + ], + "source": [ + "my_linear1 = te.Linear(768, 768).bfloat16() # The first linear - we want to run it in FP4\n", + "my_linear2 = te.Linear(768, 768).bfloat16() # The second linear - we want to run it in MXFP8\n", + "\n", + "inp = inp.bfloat16()\n", + "\n", + "with te.autocast(recipe=nvfp4_recipe):\n", + " y = my_linear1(inp)\n", + " with te.autocast(recipe=mxfp8_recipe):\n", + " out = my_linear2(y)\n", + "\n", + "print(out)\n", + "\n", + "out.mean().backward()" + ] } ], "metadata": { diff --git a/docs/examples/onnx/onnx_export.ipynb b/docs/examples/onnx/onnx_export.ipynb index 26ac71188..5ffc91854 100644 --- a/docs/examples/onnx/onnx_export.ipynb +++ b/docs/examples/onnx/onnx_export.ipynb @@ -80,7 +80,7 @@ "model = Model().eval().cuda()\n", "inps = (torch.randn([S, B, H], device=\"cuda\"),)\n", "def _inference(fp8_enabled):\n", - " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n", + " with torch.no_grad(), te.pytorch.autocast(enabled=fp8_enabled):\n", " model(*inps)\n", "\n", "te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n", @@ -138,7 +138,7 @@ "from transformer_engine.pytorch.export import te_translation_table\n", "\n", "def export(model, fname, inputs, fp8=True):\n", - " with torch.no_grad(), 
te.pytorch.fp8_autocast(enabled=fp8):\n", + " with torch.no_grad(), te.pytorch.autocast(enabled=fp8):\n", " # ! IMPORTANT !\n", " # Transformer Engine models must have warm-up run\n", " # before export. FP8 recipe during warm-up should \n", diff --git a/docs/examples/quickstart.ipynb b/docs/examples/quickstart.ipynb index 3b50ce161..0ad2f4fee 100644 --- a/docs/examples/quickstart.ipynb +++ b/docs/examples/quickstart.ipynb @@ -548,7 +548,7 @@ "\n", "
\n", - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [fp8_autocast](../api/pytorch.rst#transformer_engine.pytorch.fp8_autocast) context manager. Note that fp8_autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." + "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." ] }, { @@ -567,7 +567,7 @@ "fp8_format = Format.HYBRID\n", "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", "torch.manual_seed(1234)\n", - "with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):\n", + "with te.autocast(enabled=True, recipe=fp8_recipe):\n", " y = te_transformer(x, attention_mask=None)" ] }, @@ -591,7 +591,7 @@ " x,\n", " dy,\n", " forward_kwargs = { \"attention_mask\": None },\n", - " fp8_autocast_kwargs = { \"enabled\": True, \"fp8_recipe\": fp8_recipe },\n", + " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", ")" ] } diff --git a/docs/examples/quickstart_utils.py index f7a81d4d8..473fce7fe 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -13,7 +13,7 @@ def speedometer( input: torch.Tensor, output_grad: torch.Tensor, forward_kwargs: dict = {}, - fp8_autocast_kwargs: Optional[dict] = None, + autocast_kwargs: Optional[dict] = None, timing_iters: int = 50, warmup_iters: int = 50, ) -> None: @@ -23,20 +23,20 @@ def speedometer( """ start = 
torch.cuda.Event(enable_timing=True) end = torch.cuda.Event(enable_timing=True) - if fp8_autocast_kwargs is None: - fp8_autocast_kwargs = {"enabled": False} + if autocast_kwargs is None: + autocast_kwargs = {"enabled": False} # Warmup runs torch.cuda.synchronize() for _ in range(warmup_iters): - with te.fp8_autocast(**fp8_autocast_kwargs): + with te.autocast(**autocast_kwargs): output = module(input, **forward_kwargs) output.backward(output_grad) # Timing runs start.record() for _ in range(timing_iters): - with te.fp8_autocast(**fp8_autocast_kwargs): + with te.autocast(**autocast_kwargs): output = module(input, **forward_kwargs) output.backward(output_grad) end.record() diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py index 6285fea1a..d3de8a185 100755 --- a/docs/examples/te_gemma/te_gemma.py +++ b/docs/examples/te_gemma/te_gemma.py @@ -14,7 +14,7 @@ import transformer_engine as te from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding from transformer_engine.common.recipe import Format, DelayedScaling -from transformer_engine.pytorch.fp8 import get_default_fp8_recipe +from transformer_engine.pytorch.quantization import get_default_fp8_recipe import transformers from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel @@ -461,8 +461,8 @@ def generate( # Both autocasts are needed: FP8 for operations that can run in lower # precision and BF16 for those that cannot. 
- with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( - enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.autocast( + enabled=self.config.fp8, recipe=self.fp8_recipe if self.config.fp8 else None ): lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() # If padding is at the beginning, then shift it to the end @@ -694,8 +694,8 @@ def record_graph(self, function, input_tensor, **sample_kwargs): graphed_function = te.pytorch.make_graphed_callables( function, (input_tensor,), - fp8_enabled=self.config.fp8, - fp8_recipe=fp8_recipe, + enabled=self.config.fp8, + recipe=fp8_recipe, allow_unused_input=True, num_warmup_iters=5, sample_kwargs=sample_kwargs, diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py index d0df9edc5..36b0a5b73 100755 --- a/docs/examples/te_gemma/te_gemma_loading_weights.py +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -9,7 +9,7 @@ from typing import List -from transformer_engine.pytorch.fp8 import fp8_model_init +from transformer_engine.pytorch.quantization import quantized_model_init from transformers.modeling_utils import load_state_dict from transformers.utils.hub import get_checkpoint_shard_files @@ -88,10 +88,10 @@ def load_te_model(cls, config): config.use_cache = False # To make TransformerLayer compatible with GemmaModel # Loading model with FP8 only weights needs both the following context managers. - # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights. + # 1. quantized_model_init(config.quantized_model_init) to tell TE to use FP8 only weights. # 2. torch.no_grad() during TE modules' initilization so that they respect - # the `fp8_model_init` context manager. - with torch.no_grad(), fp8_model_init(config.fp8_model_init): + # the `quantized_model_init` context manager. 
+ with torch.no_grad(), quantized_model_init(config.quantized_model_init): # Just create a model with random weights. vanilla_model = cls(config).cuda() diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb index cc8675cfd..c31e272b2 100755 --- a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -77,7 +77,7 @@ "\n", "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n", "\n", - "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", + "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. 
Once these factors are computed, the model becomes numerically stable.\n", "\n", "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", "\n", @@ -94,12 +94,12 @@ "\n", "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n", "\n", - "The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", + "The Transformer Engine includes a wrapper `quantized_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", "\n", "
\n", "\"\"\n", "
\n", - "Figure 3: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using fp8_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", + "Figure 3: Model under autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using quantized_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", "
\n", "
\n", "\n", @@ -405,8 +405,8 @@ " graphed_function = te.pytorch.make_graphed_callables(\n", " function,\n", " (input_tensor,),\n", - " fp8_enabled=self.config.fp8,\n", - " fp8_recipe=fp8_recipe,\n", + " enabled=self.config.fp8,\n", + " recipe=fp8_recipe,\n", " allow_unused_input=True,\n", " num_warmup_iters=5,\n", " sample_kwargs=sample_kwargs,\n", @@ -540,14 +540,14 @@ "source": [ "### Calibrating FP8 scaling factors for correctness\n", "\n", - "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", + "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", "\n", "1. Model weight tensors\n", "2. Input tensors\n", "\n", "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n", "\n", - "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. 
Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", + "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", "\n", "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n", " \n", @@ -590,14 +590,14 @@ "model = init_te_gemma_model(run_config)\n", "\n", "# Calibration\n", - "with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n", + "with te.autocast(enabled=False, calibrating=True), torch.autocast(\n", " device_type=\"cuda\", dtype=torch.bfloat16\n", "):\n", " model.train()\n", " run_forward_pass(model, run_config, num_iters=64)\n", "\n", "# Compute scale_fwd with enabled fp8 autocast\n", - "with te.fp8_autocast(enabled=True), torch.autocast(\n", + "with te.autocast(enabled=True), torch.autocast(\n", " device_type=\"cuda\", dtype=torch.bfloat16\n", "):\n", " run_forward_pass(model, run_config, 1)\n", @@ -734,7 +734,7 @@ "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n", "\n", "\n", - "Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. 
Let's see a small example:" + "Transformer Engine supports maintaining FP8-only weights with the `quantized_model_init` context manager. Let's see a small example:" ] }, { @@ -778,7 +778,7 @@ "del linear_bf16\n", "\n", "# Initialize model weights in FP8 precision\n", - "with torch.no_grad(), te.fp8_model_init(enabled=True):\n", + "with torch.no_grad(), te.quantized_model_init(enabled=True):\n", " linear_fp8 = te.Linear(H, D)\n", "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", "del linear_fp8" @@ -793,11 +793,11 @@ "
\n", "\n", "
\n", - " Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", + " Figure 8: Using quantized_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", "
\n", "
\n", "\n", - "Let's run the code with `fp8_model_init`:" + "Let's run the code with `quantized_model_init`:" ] }, { @@ -862,7 +862,7 @@ "\n", "# Enable FP8 math and FP8 model weights\n", "run_config.fp8 = True\n", - "run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "run_config.quantized_model_init = True # This will result in storing only fp8 weights.\n", "run_config.fp8_model_weights_filename = (\n", " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", ")\n", @@ -885,7 +885,7 @@ "| HF (baseline) | 46.6 s | - | - | - |\n", "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n", - "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" + "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `quantized_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" ] }, { @@ -911,7 +911,7 @@ "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n", "\n", "1. Longer context lengths (with paged KV cache) \n", - "2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n", + "2. Using less memory during generation (by storing weights in FP8 precision using `quantized_model_init`)\n", "\n", "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models." 
] diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py index cc31afc65..9b67f178f 100755 --- a/docs/examples/te_gemma/utils.py +++ b/docs/examples/te_gemma/utils.py @@ -34,7 +34,7 @@ def __init__(self): # FP8 precision settings self.fp8 = False self.fp8_model_weights_filename = None - self.fp8_model_init = False + self.quantized_model_init = False # Cuda graphs self.generation_cuda_graphs = False diff --git a/docs/faq.rst b/docs/faq.rst index 2f9cbd272..a9406ed45 100644 --- a/docs/faq.rst +++ b/docs/faq.rst @@ -15,8 +15,8 @@ Here, we take the `MultiheadAttention` module as an example. Its FP8 attention m .. code-block:: python - >>> from transformer_engine.pytorch import MultiheadAttention, fp8_model_init - >>> with fp8_model_init(enabled=True): + >>> from transformer_engine.pytorch import MultiheadAttention, quantized_model_init + >>> with quantized_model_init(enabled=True): ... mha = MultiheadAttention( ... hidden_size=1024, ... num_attention_heads=16, diff --git a/docs/installation.rst b/docs/installation.rst index ecb1e9a0d..a8bb74fd1 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -38,6 +38,14 @@ Transformer Engine can be directly installed from `our PyPI (atol + rtol * jnp.abs(gathered_output)) + print(mask.astype(int)) + print(jnp.where(mask, diff, 0)) + + +# Shared constants for all tests +DP_AXIS = "data" +TPSP_AXIS = "tensor_sequence" +PARAMS_KEY = "params" + +# Shared functions for distributed testing +import argparse +import jax +from jax.experimental import mesh_utils +from transformer_engine.jax.cpp_extensions.gemm import collective_gemm_bootstrap + +# Global flag to track if distributed has been initialized +_distributed_initialized = False + + +def _is_distributed_initialized(): + """Check if JAX distributed has been initialized.""" + return _distributed_initialized + + +def _initialize_distributed(args): + """Initialize JAX distributed with custom arguments.""" + global _distributed_initialized + + # 
Check if already initialized + if _distributed_initialized: + return + + if args.coordinator_address is None or args.num_processes is None or args.process_id is None: + raise ValueError( + "All distributed initialization arguments are required: " + "--coordinator-address, --num-processes, --process-id" + ) + if args.local_device_ids is None: + assert ( + args.num_devices_per_process is not None + ), "Either local_device_ids or num_devices_per_process must be provided" + # Calculate device range for this process + # Single process single device: each process gets one unique device + # Single process multiple devices: each process gets a unique range of devices + start_device = args.process_id * args.num_devices_per_process + device_range = range(start_device, start_device + args.num_devices_per_process) + global_device_ids_for_this_process = ",".join(map(str, device_range)) + else: + # Use explicitly provided global device IDs + global_device_ids_for_this_process = args.local_device_ids + args.num_devices_per_process = len(args.local_device_ids.split(",")) + + assert args.num_devices_per_process == 1, "Only single process single GPU is supported!" 
+ + print( + f"Initializing JAX distributed with coordinator={args.coordinator_address}, " + f"num_processes={args.num_processes}, process_id={args.process_id}" + ) + # Note: "local_device_ids" is a JAX term meaning "global CUDA devices managed by this process" + jax.distributed.initialize( + coordinator_address=args.coordinator_address, + num_processes=args.num_processes, + process_id=args.process_id, + local_device_ids=global_device_ids_for_this_process, + ) + + _distributed_initialized = True + jax.clear_caches() + jax.config.update( + "jax_use_shardy_partitioner", False + ) # CollectiveGEMM does not work with Shardy yet + + assert jax.local_device_count() == 1, ( + f"[{args.process_id}|{args.num_devices_per_process}] Expected 1 GPU per process, found" + f" {jax.local_device_count()}" + ) + + devices_per_process = 1 + num_total_devices = args.num_processes + + print( + f"Initializing CGEMM communicator with num_total_devices={num_total_devices}," + f" devices_per_process={devices_per_process}, process_id={args.process_id}" + ) + + collective_gemm_bootstrap( + num_total_devices=num_total_devices, + num_devices_per_process=devices_per_process, + process_id=args.process_id, + tensor_parallel_size=args.tensor_parallel_size, + ) + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +def _create_mesh(args): + """Create mesh configuration with proper validation.""" + num_gpu = 
args.num_processes * args.num_devices_per_process + assert num_gpu == len(jax.devices()), "Number of GPUs must be equal to number of devices" + num_gpu_dp, num_gpu_tp = _get_dp_and_tp_sizes(args) + + print(f"Using {num_gpu_dp}x{num_gpu_tp} mesh ({num_gpu_dp * num_gpu_tp} total GPUs)") + + device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) + mesh = jax.sharding.Mesh(devices=device_mesh, axis_names=(DP_AXIS, TPSP_AXIS)) + return mesh + + +def cgemm_parser(description="Collective GEMM test on multi-GPU with tensor parallelism"): + """Create common argument parser for all collective GEMM tests.""" + parser = argparse.ArgumentParser(description=description) + + # Distributed initialization arguments + parser.add_argument( + "--coordinator-address", + type=str, + default=None, + help="Coordinator address for distributed initialization", + ) + parser.add_argument( + "--num-processes", + type=int, + default=None, + help="Number of processes for distributed initialization", + ) + parser.add_argument( + "--process-id", type=int, default=None, help="Process ID for distributed initialization" + ) + parser.add_argument( + "--local-device-ids", + type=str, + default=None, + help="Local device IDs for distributed initialization (comma-separated)", + ) + parser.add_argument( + "--num-devices-per-process", type=int, default=1, help="Number of devices per process" + ) + + # Test configuration arguments + parser.add_argument( + "--tensor-parallel-size", type=int, default=None, help="Tensor parallel size" + ) + parser.add_argument("--batch-size", type=int, default=4, help="Batch size for testing") + parser.add_argument("--seq-len", type=int, default=8192, help="Sequence length for testing") + parser.add_argument("--hidden-in", type=int, default=4096, help="Input hidden dimension") + parser.add_argument("--hidden-out", type=int, default=8192, help="Output hidden dimension") + parser.add_argument( + "--collective-type", + type=str, + default="all_gather", + 
choices=["all_gather", "reduce_scatter"], + help="Type of collective operation", + ) + parser.add_argument( + "--fp8-recipe", type=str, default="DelayedScaling", help="FP8 recipe to use" + ) + parser.add_argument( + "--enable-data-parallel", action="store_true", help="Enable data parallelism" + ) + parser.add_argument( + "--enable-result-check", action="store_true", default=True, help="Enable result checking" + ) + + return parser diff --git a/examples/jax/collective_gemm/conftest.py b/examples/jax/collective_gemm/conftest.py new file mode 100644 index 000000000..83937971a --- /dev/null +++ b/examples/jax/collective_gemm/conftest.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""config for collective_gemm tests""" +import pytest + + +def pytest_addoption(parser): + """Pytest hook for collective_gemm tests""" + parser.addoption("--coordinator-address", action="store", default="localhost:12345") + parser.addoption("--num-processes", action="store", default=1) + parser.addoption("--process-id", action="store", default=0) + parser.addoption("--local-device-ids", action="store", default=None) + + +@pytest.fixture(autouse=True) +def distributed_args(request): + """Fixture for querying distributed initialization arguments""" + if request.cls: + request.cls.coordinator_address = request.config.getoption("--coordinator-address") + request.cls.num_processes = int(request.config.getoption("--num-processes")) + request.cls.process_id = int(request.config.getoption("--process-id")) + request.cls.local_device_ids = request.config.getoption("--local-device-ids") + request.cls.num_devices_per_process = ( + 1 + if request.cls.local_device_ids is None + else len(request.cls.local_device_ids.split(",")) + ) diff --git a/examples/jax/collective_gemm/run_test_cgemm.sh b/examples/jax/collective_gemm/run_test_cgemm.sh new file mode 100644 index 000000000..af263eb53 --- /dev/null +++ 
b/examples/jax/collective_gemm/run_test_cgemm.sh @@ -0,0 +1,119 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +# Check if NVLINK is supported before running tests +echo "*** Checking NVLINK support***" +NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) +NVLINK_EXIT_CODE=$? + +# Check if command failed OR output indicates no NVLINK +if [ $NVLINK_EXIT_CODE -ne 0 ] || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] || [ -z "$NVLINK_OUTPUT" ]; then + echo "NVLINK is not supported on this platform" + echo "Collective GEMM tests require NVLINK connectivity" + echo "SKIPPING all tests" + exit 0 +else + echo "NVLINK support detected" +fi + +# Define the test files to run +TEST_FILES=( +"test_gemm.py" +"test_dense_grad.py" +"test_layernorm_mlp_grad.py" +) + +echo +echo "*** Executing tests in examples/jax/collective_gemm/ ***" + +HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM + +# Run each test file across all GPUs +for TEST_FILE in "${TEST_FILES[@]}"; do + echo + echo "=== Starting test file: $TEST_FILE ..." 
+ + # Clear PIDs array for this test file + PIDS=() + + for i in $(seq 0 $(($NUM_GPUS - 1))); do + # Define output file for logs + LOG_FILE="${TEST_FILE}_gpu_${i}.log" + + if [ $i -eq 0 ]; then + # For process 0: show live output AND save to log file using tee + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \ + "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + # For other processes: redirect to log files only + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ + --num-processes=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done + + # Wait for all processes to finish + wait + + # Check and print the log content from process 0 (now has log file thanks to tee) + if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE SKIPPED" + elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE FAILED" + HAS_FAILURE=1 + elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then + echo "... $TEST_FILE PASSED" + else + echo "... $TEST_FILE INVALID" + HAS_FAILURE=1 + fi + + # Remove the log files after processing them + wait + rm ${TEST_FILE}_gpu_*.log +done + +wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + +exit $HAS_FAILURE diff --git a/examples/jax/collective_gemm/test_dense_grad.py b/examples/jax/collective_gemm/test_dense_grad.py new file mode 100644 index 000000000..e14329d48 --- /dev/null +++ b/examples/jax/collective_gemm/test_dense_grad.py @@ -0,0 +1,214 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.dense import dense + +from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOp, + CollectiveOpSet, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(collective_op): + if collective_op.is_all_gather: + input_axes = (DP_AXIS, TPSP_AXIS, None) + weight_axes = (None, TPSP_AXIS) + bias_axes = (TPSP_AXIS,) + output_axes = (DP_AXIS, None, TPSP_AXIS) + else: # RS + input_axes = (DP_AXIS, None, TPSP_AXIS) + weight_axes = (TPSP_AXIS, None) + bias_axes = (None,) + output_axes = (DP_AXIS, TPSP_AXIS, None) + return input_axes, weight_axes, bias_axes, output_axes + + +def _get_operand_sharding(mesh, collective_op): + input_axes, weight_axes, bias_axes, _ = _get_logical_axes(collective_op) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_axes)) + weight_sharding = NamedSharding(mesh, PartitionSpec(*weight_axes)) + bias_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes)) + return x_sharding, weight_sharding, bias_sharding + + +def _mean_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + output = dense( + x, + weight, + bias, + contracting_dims=((2,), (0,)), + input_axes=input_axes, + kernel_axes=weight_axes, + output_axes=output_axes, + collective_op_set=collective_op_set, + ) + return jnp.mean(output.astype(jnp.float32)) + + +def _value_and_grad_dense(x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set): + return 
jax.jit(jax.value_and_grad(_mean_dense, (0, 1, 2)), static_argnums=(3, 4, 5, 6))( + x, weight, bias, input_axes, weight_axes, output_axes, collective_op_set + ) + + +def run_dense_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + collective_op_set = CollectiveOpSet.create(forward_collective_op=collective_op) + + with mesh, autocast( + enabled=False, + recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. 
This must be done inside autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + + x_sharding, weight_sharding, bias_sharding = _get_operand_sharding(mesh, collective_op) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + input_axes, weight_axes, _, output_axes = _get_logical_axes(collective_op) + ref_output, ref_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + noop_collective_op_set, + ) + output, sharded_grads = _value_and_grad_dense( + x_sharded, + weight_sharded, + bias_sharded, + input_axes, + weight_axes, + output_axes, + collective_op_set, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveDenseGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + 
self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather(self): + """Test Collective Dense Gradient with AllGather""" + self.args.collective_type = "all_gather" + run_dense_grad_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter(self): + """Test Collective Dense Gradient with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_dense_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_dense_grad.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_dense_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective Dense Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_dense_grad_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_gemm.py b/examples/jax/collective_gemm/test_gemm.py new file mode 100644 index 000000000..ac86c551d --- /dev/null +++ b/examples/jax/collective_gemm/test_gemm.py @@ -0,0 +1,206 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective GEMM test on multi-GPU with tensor parallelism + +This script uses custom distributed initialization with the following arguments: +- --coordinator-address: Coordinator address for distributed initialization +- --num-processes: Number of processes for distributed initialization +- --process-id: Process ID for distributed initialization +- --local-device-ids: Local device IDs for distributed initialization + +Example: + python test_gemm.py --coordinator-address localhost:1234 --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3 +""" +import unittest +import os +from functools import partial + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.cpp_extensions.gemm import CollectiveOp +from 
transformer_engine.jax.sharding import MeshResource + + +def _get_operand_sharding(mesh, collective_op, is_with_dp): + + dp_axis = DP_AXIS if is_with_dp else None + if collective_op == CollectiveOp.ALL_GATHER: + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + weight_sharding = NamedSharding(mesh, PartitionSpec(None, TPSP_AXIS)) + bias_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + else: # RS + x_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, None, TPSP_AXIS)) + weight_sharding = NamedSharding(mesh, PartitionSpec(TPSP_AXIS, None)) + bias_sharding = NamedSharding(mesh, PartitionSpec(None)) + output_sharding = NamedSharding(mesh, PartitionSpec(dp_axis, TPSP_AXIS, None)) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +def _get_dp_and_tp_sizes(args): + num_gpu = args.num_processes * args.num_devices_per_process + if args.tensor_parallel_size is None: + num_gpu_dp = 2 if args.enable_data_parallel else 1 + assert ( + num_gpu > 1 and num_gpu % num_gpu_dp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_tp = num_gpu // num_gpu_dp + else: + num_gpu_tp = args.tensor_parallel_size + assert ( + num_gpu > 1 and num_gpu % num_gpu_tp == 0 + ), "Number of GPUs must be greater than 1 and divisible by number of data parallel GPUs" + num_gpu_dp = num_gpu // num_gpu_tp + return num_gpu_dp, num_gpu_tp + + +@partial(jax.jit, static_argnames=("contracting_dims", "collective_op", "output_sharding")) +def _jitted_cgemm(x, weight, bias, contracting_dims, collective_op, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + collective_op=collective_op, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +def run_gemm_tests(args, mesh=None): + """Execute GEMM 
tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_rng, bias_rng = jax.random.split(rng, 4) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight = jax.random.normal(weight_rng, (args.hidden_in, args.hidden_out), dtype=jnp.bfloat16) + bias = jax.random.normal(bias_rng, (args.hidden_out,), dtype=jnp.bfloat16) + collective_op = ( + CollectiveOp.ALL_GATHER + if args.collective_type == "all_gather" + else CollectiveOp.REDUCE_SCATTER + ) + + with mesh, autocast( + enabled=False, + recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + print(f"Device mesh: {mesh}") + + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_operand_sharding( + mesh, collective_op, args.enable_data_parallel + ) + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + ref_output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=CollectiveOp.NONE, + output_sharding=output_sharding, + ) + output = _jitted_cgemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=((2,), (0,)), + collective_op=collective_op, + # CollectiveGEMM output should have a correct sharding without applying sharding constraint + output_sharding=None, + ) + assert ( + ref_output.sharding == output.sharding + ), f"ref_output.sharding={ref_output.sharding}, output.sharding={output.sharding}" + gathered_ref_output = jax.lax.with_sharding_constraint( + ref_output, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_output = 
jax.lax.with_sharding_constraint( + output, NamedSharding(mesh, PartitionSpec(None)) + ) + jax.block_until_ready(gathered_ref_output) + jax.block_until_ready(gathered_output) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(gathered_ref_output, gathered_output) + + +class TestCollectiveGemmWithDP(unittest.TestCase): + """Collective GEMM with DP unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective GEMM test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_all_gather_with_dp(self): + """Test Collective GEMM with AllGather""" + self.args.collective_type = "all_gather" + run_gemm_tests(self.args, self.mesh) + + def test_te_bf16_reduce_scatter_with_dp(self): + """Test Collective GEMM with ReduceScatter""" + self.args.collective_type = "reduce_scatter" + run_gemm_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 5: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_gemm.py --coordinator-address
--num-processes " + " --process-id [--local-device-ids ] [other args]" + ) + sys.exit(1) + + args = cgemm_parser("Collective GEMM test on multi-GPU with tensor parallelism").parse_args() + _initialize_distributed(args) + run_gemm_tests(args, mesh=None) diff --git a/examples/jax/collective_gemm/test_layernorm_mlp_grad.py b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py new file mode 100644 index 000000000..407cec68a --- /dev/null +++ b/examples/jax/collective_gemm/test_layernorm_mlp_grad.py @@ -0,0 +1,272 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Collective Dense Gradient test on multi-GPU with tensor parallelism""" +import argparse +import unittest +import os + +import jax +import jax.numpy as jnp +from jax.sharding import PartitionSpec, NamedSharding +import flax + +from common import ( + assert_allclose, + _initialize_distributed, + _get_dp_and_tp_sizes, + _create_mesh, + DP_AXIS, + TPSP_AXIS, + PARAMS_KEY, + cgemm_parser, +) + +from transformer_engine.jax.layernorm_mlp import layernorm_mlp + +from transformer_engine.jax.quantize import autocast +from transformer_engine.jax.cpp_extensions.gemm import ( + CollectiveOpSet, + CollectiveOp, + noop_collective_op_set, +) +from transformer_engine.jax.sharding import MeshResource +import transformer_engine.jax.flax as te_flax + + +def _get_logical_axes(): + input_1_axes = (DP_AXIS, TPSP_AXIS, None) + weight_1_axes = (None, None, TPSP_AXIS) + bias_axes_1 = (None, TPSP_AXIS) + input_2_axes = (DP_AXIS, None, TPSP_AXIS) + weight_2_axes = (TPSP_AXIS, None) + bias_axes_2 = (None,) + return input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 + + +def _get_operand_sharding(mesh): + input_1_axes, weight_1_axes, bias_axes_1, input_2_axes, weight_2_axes, bias_axes_2 = ( + _get_logical_axes() + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*input_1_axes)) + weight_1_sharding = NamedSharding(mesh, 
PartitionSpec(*weight_1_axes)) + bias_1_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_1)) + weight_2_sharding = NamedSharding(mesh, PartitionSpec(*weight_2_axes)) + bias_2_sharding = NamedSharding(mesh, PartitionSpec(*bias_axes_2)) + return x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding + + +def _mean_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + output = layernorm_mlp( + x, + gamma, + beta=None, + kernels=[weight_1, weight_2], + biases=[bias_1, bias_2], + norm_type="rmsnorm", + dot_1_input_axes=input_1_axes, + dot_2_input_axes=input_2_axes, + kernel_1_axes=weight_1_axes, + kernel_2_axes=weight_2_axes, + activation_type=("gelu",), + collective_op_sets=collective_op_sets, + ) + return jnp.mean(output) + + +def _value_and_grad_layernorm_mlp( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, +): + return jax.jit( + jax.value_and_grad(_mean_layernorm_mlp, (0, 1, 2, 3, 4, 5)), static_argnums=(6, 7, 8, 9, 10) + )( + x, + weight_1, + bias_1, + weight_2, + bias_2, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + + +def run_layernorm_mlp_grad_tests(args, mesh=None): + """Execute Dense Gradient tests.""" + print(args) + # Collective GEMM requires Shardy partitioner to be disabled + jax.config.update("jax_use_shardy_partitioner", False) + + # Initialize distributed with provided arguments + _initialize_distributed(args) + + mesh = mesh or _create_mesh(args) + + # Create test data + rng = jax.random.PRNGKey(0) + rng, x_rng, weight_1_rng, bias_1_rng, weight_2_rng, bias_2_rng, gamma_rng = jax.random.split( + rng, 7 + ) + x = jax.random.normal( + x_rng, (args.batch_size, args.seq_len, args.hidden_in), dtype=jnp.bfloat16 + ) + weight_1 = jax.random.normal( + 
weight_1_rng, (args.hidden_in, 1, args.hidden_out), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_in) + bias_1 = jax.random.normal(bias_1_rng, (1, args.hidden_out), dtype=jnp.bfloat16) + weight_2 = jax.random.normal( + weight_2_rng, (args.hidden_out, args.hidden_in), dtype=jnp.bfloat16 + ) / jnp.sqrt(args.hidden_out) + bias_2 = jax.random.normal(bias_2_rng, (args.hidden_in,), dtype=jnp.bfloat16) + gamma = jax.random.normal(gamma_rng, (args.hidden_in,), dtype=jnp.bfloat16) / jnp.sqrt( + args.hidden_in + ) + collective_op_set_1 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.ALL_GATHER) + collective_op_set_2 = CollectiveOpSet.create(forward_collective_op=CollectiveOp.REDUCE_SCATTER) + collective_op_sets = (collective_op_set_1, collective_op_set_2) + noop_collective_op_sets = (noop_collective_op_set, noop_collective_op_set) + + with mesh, autocast( + enabled=False, + recipe=None, + mesh_resource=MeshResource(dp_resource=DP_AXIS, tpsp_resource=TPSP_AXIS), + ): + # Get the base axis rules and extend them with TE's rules. 
This must be done inside autocast + axis_rules = flax.linen.get_logical_axis_rules() + axis_rules += ((TPSP_AXIS, TPSP_AXIS), (DP_AXIS, DP_AXIS)) + te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) + with flax.linen.logical_axis_rules(te_extended_axis_rules): + x_sharding, weight_1_sharding, bias_1_sharding, weight_2_sharding, bias_2_sharding = ( + _get_operand_sharding(mesh) + ) + x_sharded = jax.device_put(x, x_sharding) + weight_1_sharded = jax.device_put(weight_1, weight_1_sharding) + bias_1_sharded = jax.device_put(bias_1, bias_1_sharding) + weight_2_sharded = jax.device_put(weight_2, weight_2_sharding) + bias_2_sharded = jax.device_put(bias_2, bias_2_sharding) + + input_1_axes, weight_1_axes, _, input_2_axes, weight_2_axes, _ = _get_logical_axes() + ref_output, ref_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + noop_collective_op_sets, + ) + output, sharded_grads = _value_and_grad_layernorm_mlp( + x_sharded, + weight_1_sharded, + bias_1_sharded, + weight_2_sharded, + bias_2_sharded, + gamma, + input_1_axes, + input_2_axes, + weight_1_axes, + weight_2_axes, + collective_op_sets, + ) + jax.block_until_ready(ref_output) + jax.block_until_ready(output) + gathered_grads = [] + gathered_ref_grads = [] + for ref_grad, grad in zip(ref_grads, sharded_grads): + gathered_grads.append( + jax.lax.with_sharding_constraint(grad, NamedSharding(mesh, PartitionSpec(None))) + ) + gathered_ref_grads.append( + jax.lax.with_sharding_constraint(ref_grad, NamedSharding(mesh, PartitionSpec(None))) + ) + jax.block_until_ready(gathered_grads) + jax.block_until_ready(gathered_ref_grads) + + if args.enable_result_check and args.process_id == 0: + assert_allclose(ref_output, output, dtype=jnp.bfloat16) + for ref_grad, gathered_grad in zip(gathered_ref_grads, gathered_grads): + assert_allclose(ref_grad, 
gathered_grad, dtype=jnp.bfloat16) + + +class TestCollectiveLayerNormMLPGradient(unittest.TestCase): + """Collective Dense Gradient unittests""" + + def setUp(self): + self.args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + self.args.coordinator_address = self.coordinator_address + self.args.num_processes = self.num_processes + self.args.process_id = self.process_id + self.args.local_device_ids = self.local_device_ids + self.args.num_devices_per_process = self.num_devices_per_process + self.args.enable_data_parallel = True + self.args.tensor_parallel_size = _get_dp_and_tp_sizes(self.args)[1] + _initialize_distributed(self.args) + # Create mesh once for all tests + self.mesh = _create_mesh(self.args) + jax.sharding.set_mesh(self.mesh) + self.args.enable_result_check = True + os.environ["NVTE_JAX_ALL_REDUCE_IN_FP32"] = "1" + + def tearDown(self): + os.environ.pop("NVTE_JAX_ALL_REDUCE_IN_FP32", None) + + def test_te_bf16_layernorm_mlp_grad(self): + """Test Collective Dense Gradient with AllGather""" + run_layernorm_mlp_grad_tests(self.args, self.mesh) + + +if __name__ == "__main__": + import sys + + if len(sys.argv) < 7: # Need at least the 3 required distributed args + print("Error: This script requires distributed initialization arguments.") + print( + "Usage: python test_layernorm_mlp_grad.py --coordinator-address
" + " --num-processes --process-id [--local-device-ids ] [other args]" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 4 --process-id 0" + ) + print( + "Example: python test_layernorm_mlp_grad.py --coordinator-address localhost:1234" + " --num-processes 2 --process-id 0 --local-device-ids 0,1,2,3" + ) + sys.exit(1) + + args = cgemm_parser( + "Collective LayerNorm MLP Gradient test on multi-GPU with tensor parallelism" + ).parse_args([]) + _initialize_distributed(args) + run_layernorm_mlp_grad_tests(args, mesh=None) diff --git a/examples/jax/encoder/README.md b/examples/jax/encoder/README.md index 575f7be6e..fc6696317 100644 --- a/examples/jax/encoder/README.md +++ b/examples/jax/encoder/README.md @@ -8,7 +8,7 @@ This example uses Transformer Encoder to demonstrate the Transformer Engine usag 2. Define model: The `Net` class is a small Transformer Encoder model for sentence classification. The Transformer Engine provides `te.TransformerLayer` as encoder block and `te.DenseGeneral`. The structure of encoder block can be referred to [Scaling Up Models and Data with t5x and seqio](https://arxiv.org/abs/2203.17189) -3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `fp8_autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`. +3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. Use `autocast` context manager to enable FP8 training and check `var_collect` if the variable collection contains `Float8`. 4. Training process: In `train_step`, combine the FP8 metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function. @@ -29,7 +29,7 @@ python test_single_gpu_encoder.py --use-fp8 3. 
On the model side, the logical axis of each weight tensor of the model can be named. The `te.TransformerLayer` has the default names, which are stored in `abs_var_collect`, a collection of variables returned by `jax.eval_shape(encoder.init, ...)`. The key index is `params_axes`. The `te.DenseGeneral` doesn't have the default named axis because it is generic. Also, data-parallel sharding doesn't need to divide weight tensor, so named axis is not required for this case. -4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under fp8_autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. +4. The next is to create sharding rules, mapping the device axis to the logical axis. The `te.extend_logical_axis_rules` under autocast will return a list of pairs of the mapping, such as `(('batch', 'data'), ...)`. The first is the logical axis and second is the device axis. 5. Refer structure of `abs_var_collect['params']` and `abs_var_collect['params_axes']` to set up `PartitionSpec` for parallel jit. All logical axes should be replaced by device axes. If the value of PartitionSpec is None, that means no sharding, broadcasting the data to every device. Note that the `params_axes` attribute is provided by Transformer Engine. The Flax's module doesn't have it, such as `nn.Embed`. For nn.Embed, assigning an empty PartitionSpec is fine because each device has its own embedding layer in DP mode. The `get_params_pspec` routine is used for this purpose. Because each device has a complete model in DP mode, all values of PartitionSpec in params_pspec should be None. This will be different in the model parallelism example. 
@@ -136,4 +136,4 @@ numactl --cpunodebind=112 --membind=7 python test_multiprocessing_encoder.py --n numactl --cpunodebind=113 --membind=7 python test_multiprocessing_encoder.py --num-process 8 --process-id 5 & numactl --cpunodebind=80 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 6 & numactl --cpunodebind=81 --membind=5 python test_multiprocessing_encoder.py --num-process 8 --process-id 7 & -``` \ No newline at end of file +``` diff --git a/examples/jax/encoder/common.py b/examples/jax/encoder/common.py index 98c984839..dd7caff8d 100644 --- a/examples/jax/encoder/common.py +++ b/examples/jax/encoder/common.py @@ -5,6 +5,9 @@ # See LICENSE for license information. """Shared functions for the encoder tests""" from functools import lru_cache +import os +import pathlib +import zipfile import jax import jax.numpy @@ -47,6 +50,13 @@ def is_mxfp8_supported(): return gpu_arch >= 100 +@lru_cache +def is_nvfp4_supported(): + """Return if FP8 has hardware supported""" + gpu_arch = get_device_compute_capability(0) + return gpu_arch >= 100 + + def assert_params_sufficiently_sharded(params, mesh, tolerance=0.01, print_info=False): """Checks whether most params are sharded across sharding axis. 
@@ -112,7 +122,7 @@ def assert_leaf_sharding(path, arr): ) -def get_fp8_recipe_from_name_string(name: str): +def get_quantization_recipe_from_name_string(name: str): """Query recipe from a given name string""" match name: case "DelayedScaling": @@ -121,5 +131,54 @@ def get_fp8_recipe_from_name_string(name: str): return recipe.MXFP8BlockScaling() case "Float8CurrentScaling": return recipe.Float8CurrentScaling() + case "NVFP4BlockScaling": + return recipe.NVFP4BlockScaling() case _: - raise ValueError(f"Invalid fp8_recipe, got {name}") + raise ValueError(f"Invalid quantization_recipe, got {name}") + + +@lru_cache(maxsize=None) +def _get_example_artifacts_dir() -> pathlib.Path: + """Path to directory with pre-downloaded datasets""" + + # Check environment variable + path = os.getenv("NVTE_TEST_CHECKPOINT_ARTIFACT_PATH") + if path: + return pathlib.Path(path).resolve() + + # Fallback to path in root dir + root_dir = pathlib.Path(__file__).resolve().parent.parent.parent + return root_dir / "artifacts" / "examples" / "jax" + + +def _unpack_cached_dataset(artifacts_dir: pathlib.Path, folder_name: str) -> None: + """Unpack a cached dataset if available""" + dataset_dir = artifacts_dir / folder_name + if not dataset_dir.exists(): + print(f"Cached dataset {folder_name} not found at {dataset_dir}, skipping unpack") + return + + # Disable any HF network calls since the dataset is cached locally + os.environ["HF_HUB_OFFLINE"] = "1" + + for filename in os.listdir(dataset_dir): + filepath = dataset_dir / filename + if not filename.endswith(".zip"): + continue + print(f"Unpacking cached dataset {folder_name} from {filepath}") + + with zipfile.ZipFile(filepath, "r") as zip_ref: + zip_ref.extractall(pathlib.Path.home() / ".cache" / "huggingface") + print( + f"Unpacked cached dataset {folder_name} to" + f" {pathlib.Path.home() / '.cache' / 'huggingface'}" + ) + + +# This is cached so we don't have to unpack datasets multiple times +@lru_cache(maxsize=None) +def 
unpack_cached_datasets_if_available() -> None: + """Unpack cached datasets if available""" + artifacts_dir = _get_example_artifacts_dir() + _unpack_cached_dataset(artifacts_dir, "mnist") + _unpack_cached_dataset(artifacts_dir, "encoder") diff --git a/examples/jax/encoder/run_test_multiprocessing_encoder.sh b/examples/jax/encoder/run_test_multiprocessing_encoder.sh index 2a1ac0f8f..fa7102cb4 100644 --- a/examples/jax/encoder/run_test_multiprocessing_encoder.sh +++ b/examples/jax/encoder/run_test_multiprocessing_encoder.sh @@ -10,16 +10,44 @@ TEST_CASES=( "test_te_delayed_scaling_fp8" "test_te_current_scaling_fp8" "test_te_mxfp8" +"test_te_nvfp4" "test_te_bf16_shardy" "test_te_delayed_scaling_fp8_shardy" "test_te_current_scaling_fp8_shardy" +"test_te_nvfp4_shardy" ) +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + echo echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" HAS_FAILURE=0 # Global failure flag +PIDS=() # Array to store all process PIDs + +# Cleanup function to kill all processes +cleanup() { + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Killing process $pid" + kill -TERM "$pid" 2>/dev/null || true + fi + done + # Wait a bit and force kill if needed + sleep 2 + for pid in "${PIDS[@]}"; do + if kill -0 "$pid" 2>/dev/null; then + echo "Force killing process $pid" + kill -KILL "$pid" 2>/dev/null || true + fi + done +} + +# Set up signal handlers to cleanup on exit +trap cleanup EXIT INT TERM # Run each test case across all GPUs for TEST_CASE in "${TEST_CASES[@]}"; do echo @@ -29,25 +57,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do # Define output file for logs LOG_FILE="${TEST_CASE}_gpu_${i}.log" - # Run pytest and redirect stdout and stderr to the log file - pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ - -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ - --num-process=$NUM_GPUS \ - --process-id=$i > 
"$LOG_FILE" 2>&1 & - done + # For process 0: show live output AND save to log file using tee + if [ $i -eq 0 ]; then + echo "=== Live output from process 0 ===" + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \ + "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i 2>&1 | tee "$LOG_FILE" & + PID=$! + PIDS+=($PID) + else + pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ + -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ + --num-process=$NUM_GPUS \ + --process-id=$i > "$LOG_FILE" 2>&1 & + PID=$! + PIDS+=($PID) + fi + done # Wait for the process to finish wait - tail -n +7 "${TEST_CASE}_gpu_0.log" # Check and print the log content accordingly if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE SKIPPED" + elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then + echo "... $TEST_CASE FAILED" + HAS_FAILURE=1 elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then echo "... $TEST_CASE PASSED" else + echo "... $TEST_CASE INVALID" HAS_FAILURE=1 - echo "... 
$TEST_CASE FAILED" fi # Remove the log file after processing it @@ -56,4 +99,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do done wait + +# Final cleanup (trap will also call cleanup on exit) +cleanup + exit $HAS_FAILURE diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 3855db275..2f4ee3dda 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -23,14 +23,16 @@ from common import ( is_bf16_supported, - get_fp8_recipe_from_name_string, + get_quantization_recipe_from_name_string, assert_params_sufficiently_sharded, + unpack_cached_datasets_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode +unpack_cached_datasets_if_available() DEVICE_DP_AXIS = "data" DEVICE_TP_AXIS = "model" @@ -38,6 +40,7 @@ NAMED_TP_AXIS = "my_tp_axis" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +SR_KEY = "sr_rng" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -123,6 +126,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] 
@@ -137,11 +142,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -152,7 +157,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn): +def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -161,11 +166,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -225,7 +232,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
- rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -259,16 +266,16 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, te.fp8_autocast( + ) as mesh, te.autocast( enabled=args.use_fp8, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, mesh_resource=te.MeshResource( dp_resource=DEVICE_DP_AXIS, tpsp_resource=DEVICE_TP_AXIS, @@ -277,13 +284,14 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - # Get the base axis rules and extend them with TE's rules. This must be done inside fp8_autocast + # Get the base axis rules and extend them with TE's rules. 
This must be done inside autocast axis_rules = flax.linen.get_logical_axis_rules() axis_rules += ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) te_extended_axis_rules = te_flax.extend_logical_axis_rules(axis_rules) @@ -357,7 +365,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -369,22 +384,24 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng_state} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs ) print( @@ -404,16 +421,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for training (default: 128)", + help="input batch size for training (default: 256)", ) parser.add_argument( "--test-batch-size", type=int, - default=128, + default=256, 
metavar="N", - help="input batch size for testing (default: 128)", + help="input batch size for testing (default: 256)", ) parser.add_argument( "--max-seq-len", @@ -468,8 +485,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 5 epochs for testing""" @@ -479,7 +497,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -487,7 +505,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.361 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -495,14 +513,22 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not 
is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_with_sp(self): """Test Transformer Engine with BF16 + SP""" self.args.enable_sp = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp(self): @@ -511,7 +537,7 @@ def test_te_delayed_scaling_fp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8_with_sp(self): @@ -520,14 +546,23 @@ def test_te_mxfp8_with_sp(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_with_sp(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -536,7 +571,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + 
assert actual[0] < 0.36 and actual[1] > 0.84 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_with_sp_shardy(self): @@ -546,24 +581,27 @@ def test_te_delayed_scaling_fp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.361 and actual[1] > 0.84 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." - ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
- ) def test_te_mxfp8_with_sp_shardy(self): """Test Transformer Engine with MXFP8 + SP""" self.args.enable_shardy = True @@ -571,7 +609,17 @@ def test_te_mxfp8_with_sp_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.39 and actual[1] > 0.83 + assert actual[0] < 0.36 and actual[1] > 0.84 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_with_sp_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.enable_sp = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.40 and actual[1] > 0.82 if __name__ == "__main__": diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index d6bfddb3e..04e5bb92f 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -21,17 +21,23 @@ from jax.experimental import mesh_utils from jax.sharding import PartitionSpec, NamedSharding -from common import is_bf16_supported, get_fp8_recipe_from_name_string +from common import ( + is_bf16_supported, + get_quantization_recipe_from_name_string, + unpack_cached_datasets_if_available, +) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode +unpack_cached_datasets_if_available() DEVICE_DP_AXIS = "data" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" DROPOUT_KEY = "dropout" +SR_KEY = "sr_rng" INPUT_KEY = "input_rng" @@ -99,6 +105,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure 
unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] @@ -113,11 +121,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect, train_fn): return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -128,7 +136,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect, eval_fn): +def eval_model(state, test_ds, batch_size, var_collect, eval_fn, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -137,11 +145,13 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_fn(state, batch_inputs, batch_masks, batch_labels, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -201,7 +211,7 @@ def 
get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." - rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -256,29 +266,28 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu,)) - with jax.sharding.Mesh( - devices=device_mesh, axis_names=(DEVICE_DP_AXIS,) - ) as mesh, te.fp8_autocast( + with jax.sharding.Mesh(devices=device_mesh, axis_names=(DEVICE_DP_AXIS,)) as mesh, te.autocast( enabled=args.use_fp8, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS), ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) + rng, sr_rng = jax.random.split(rng) init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] - # Add TE logical axis rules to our Flax logical axis rule context. This must be done inside fp8_autocast + # Add TE logical axis rules to our Flax logical axis rule context. 
This must be done inside autocast sharding_rules = te_flax.extend_logical_axis_rules(tuple()) with flax.linen.logical_axis_rules(sharding_rules): encoder = Net(num_embed) @@ -324,7 +333,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -336,22 +352,24 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, args.batch_size, rngs, var_collect, jit_train_step ) test_loss, test_accuracy = eval_model( - state, test_ds, args.test_batch_size, var_collect, jit_eval_step + state, test_ds, args.test_batch_size, var_collect, jit_eval_step, rngs ) print( @@ -371,16 +389,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=256, + default=512, metavar="N", - help="input batch size for training (default: 256)", + help="input batch size for training (default: 512)", ) parser.add_argument( "--test-batch-size", type=int, - default=256, + default=512, metavar="N", - help="input batch size for testing (default: 256)", + 
help="input batch size for testing (default: 512)", ) parser.add_argument( "--max-seq-len", @@ -432,8 +450,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 5 epochs for testing""" @@ -443,7 +462,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -451,7 +470,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -459,7 +478,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.749 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -467,6 +486,14 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) + assert actual[0] < 0.51 and actual[1] > 0.75 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): 
+ """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") @@ -474,7 +501,7 @@ def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" self.args.enable_shardy = True actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8_shardy(self): @@ -483,7 +510,7 @@ def test_te_delayed_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.75 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8_shardy(self): @@ -492,18 +519,24 @@ def test_te_current_scaling_fp8_shardy(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.52 and actual[1] > 0.74 + assert actual[0] < 0.51 and actual[1] > 0.749 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
- ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" self.args.enable_shardy = True self.args.use_fp8 = True self.args.fp8_recipe = "MXFP8BlockScaling" actual = train_and_evaluate(self.args) + assert actual[0] < 0.51 and actual[1] > 0.75 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + self.args.enable_shardy = True + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) assert actual[0] < 0.52 and actual[1] > 0.74 diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index 420e36ea1..7f3c5c8e1 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -27,12 +27,15 @@ is_bf16_supported, is_fp8_supported, is_mxfp8_supported, - get_fp8_recipe_from_name_string, + is_nvfp4_supported, + get_quantization_recipe_from_name_string, + unpack_cached_datasets_if_available, ) import transformer_engine.jax as te import transformer_engine.jax.cpp_extensions as tex import transformer_engine.jax.flax as te_flax +unpack_cached_datasets_if_available() os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" DEVICE_DP_AXIS = "data" @@ -41,6 +44,7 @@ NAMED_TP_AXIS = "my_tp_axis" PARAMS_KEY = "params" PARAMS_AXES_KEY = PARAMS_KEY + "_axes" +SR_KEY = "sr_rng" DROPOUT_KEY = "dropout" INPUT_KEY = "input_rng" @@ -177,6 +181,8 @@ def train_epoch( epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_input = sentence[perm, ...] batch_mask = mask[perm, ...] batch_label = label[perm, ...] 
@@ -202,11 +208,11 @@ def train_epoch( return state, avg_loss, avg_accuracy, var_collect -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels, 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -218,7 +224,16 @@ def loss_fn(var_collect, disable_dropout=False): def eval_model( - state, test_ds, batch_size, var_collect, eval_fn, mesh, inputs_pspec, masks_pspec, labels_pspec + state, + test_ds, + batch_size, + var_collect, + eval_fn, + mesh, + inputs_pspec, + masks_pspec, + labels_pspec, + rngs, ): """Evaluation loop.""" global_input_shape, input_named_sharding, sentence = shard_array_wrapper( @@ -235,7 +250,8 @@ def eval_model( all_accuracy = [] for batch_input, batch_mask, batch_label in zip(sentence, mask, label): - + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} shard_input = jax.make_array_from_single_device_arrays( global_input_shape, input_named_sharding, [batch_input] ) @@ -246,7 +262,7 @@ def eval_model( global_label_shape, label_named_sharding, [batch_label] ) - loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect) + loss, accuracy = eval_fn(state, shard_input, shard_mask, shard_label, var_collect, rngs) all_loss.append(loss) all_accuracy.append(accuracy) @@ -305,7 +321,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
- rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -374,16 +390,16 @@ def train_and_evaluate(args): ), "Test batch size needs to be multiple of 32 for MXFP8" if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None device_mesh = mesh_utils.create_device_mesh((num_gpu_dp, num_gpu_tp)) with jax.sharding.Mesh( devices=device_mesh, axis_names=(DEVICE_DP_AXIS, DEVICE_TP_AXIS) - ) as mesh, te.fp8_autocast( + ) as mesh, te.autocast( enabled=args.use_fp8, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, mesh_resource=te.MeshResource( dp_resource=DEVICE_DP_AXIS, tpsp_resource=DEVICE_TP_AXIS, @@ -392,7 +408,8 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] @@ -400,7 +417,7 @@ def train_and_evaluate(args): # Create custom Flax logical axis rules for sharding. customized_rules = ((NAMED_BROADCAST_AXIS, None), (NAMED_TP_AXIS, DEVICE_TP_AXIS)) - # Extend the logical axis rules with TE's rules. This must be done inside fp8_autocast. + # Extend the logical axis rules with TE's rules. This must be done inside autocast. 
sharding_rules = te_flax.extend_logical_axis_rules(customized_rules) with flax.linen.logical_axis_rules(sharding_rules): @@ -446,7 +463,14 @@ def train_and_evaluate(args): train_step, in_shardings=in_shardings, out_shardings=out_shardings ) - in_shardings = (state_sharding, inputs_sharding, masks_sharding, labels_sharding, None) + in_shardings = ( + state_sharding, + inputs_sharding, + masks_sharding, + labels_sharding, + None, + None, + ) out_shardings = (None, None) jit_eval_step = jax.jit( eval_step, in_shardings=in_shardings, out_shardings=out_shardings @@ -458,14 +482,16 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} jit_train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") else: for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, @@ -490,6 +516,7 @@ def train_and_evaluate(args): inputs_pspec, masks_pspec, labels_sharding.spec, + rngs, ) if args.process_id == 0: print( @@ -510,16 +537,16 @@ def encoder_parser(args): parser.add_argument( "--batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for training (default: 128)", + help="input batch size for training (default: 256)", ) parser.add_argument( "--test-batch-size", type=int, - default=128, + default=256, metavar="N", - help="input batch size for testing (default: 128)", + help="input batch size for testing (default: 256)", ) parser.add_argument( "--max-seq-len", @@ -631,7 +658,7 @@ def test_te_delayed_scaling_fp8(self): def 
test_te_current_scaling_fp8(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling") - assert result[0] < 0.43 and result[1] > 0.80 + assert result[0] < 0.432 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" @@ -641,6 +668,14 @@ def test_te_mxfp8(self): result = self.exec(True, "MXFP8BlockScaling") assert result[0] < 0.43 and result[1] > 0.80 + @unittest.skipIf( + not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" + ) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + result = self.exec(True, "NVFP4BlockScaling") + assert result[0] < 0.451 and result[1] > 0.787 + @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16") def test_te_bf16_shardy(self): """Test Transformer Engine with BF16""" @@ -661,19 +696,24 @@ def test_te_delayed_scaling_fp8_shardy(self): def test_te_current_scaling_fp8_shardy(self): """Test Transformer Engine with CurrentScaling FP8""" result = self.exec(True, "Float8CurrentScaling", enable_shardy=True) - assert result[0] < 0.43 and result[1] > 0.80 + assert result[0] < 0.432 and result[1] > 0.80 @unittest.skipIf( not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8" ) - @unittest.skipIf( - tex.gemm_uses_jax_dot(), "`jax.nn.scaled_matmul()` does not support the Shardy partitioner." 
- ) def test_te_mxfp8_shardy(self): """Test Transformer Engine with MXFP8""" result = self.exec(True, "MXFP8BlockScaling", enable_shardy=True) assert result[0] < 0.43 and result[1] > 0.80 + @unittest.skipIf( + not is_nvfp4_supported(), "Device compute capability 10.0+ is required for NVFP4" + ) + def test_te_nvfp4_shardy(self): + """Test Transformer Engine with NVFP4""" + result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True) + assert result[0] < 0.451 and result[1] > 0.787 + if __name__ == "__main__": train_and_evaluate(encoder_parser(None)) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index 2c5bd7025..0d74876ef 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -18,14 +18,20 @@ from flax import linen as nn from flax.training import train_state -from common import is_bf16_supported, get_fp8_recipe_from_name_string +from common import ( + is_bf16_supported, + get_quantization_recipe_from_name_string, + unpack_cached_datasets_if_available, +) import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode +unpack_cached_datasets_if_available() PARAMS_KEY = "params" DROPOUT_KEY = "dropout" +SR_KEY = "sr_rng" INPUT_KEY = "input_rng" @@ -94,6 +100,8 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): epoch_accuracy = [] for perm in perms: + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_inputs = train_ds["sentence"][perm, ...] batch_masks = train_ds["mask"][perm, ...] batch_labels = train_ds["label"][perm, ...] 
@@ -109,11 +117,11 @@ def train_epoch(state, train_ds, batch_size, rngs, var_collect): @jax.jit -def eval_step(state, inputs, masks, labels, var_collect): +def eval_step(state, inputs, masks, labels, var_collect, rngs): """Computes loss and accuracy for a single batch.""" def loss_fn(var_collect, disable_dropout=False): - logits = state.apply_fn(var_collect, inputs, masks, disable_dropout) + logits = state.apply_fn(var_collect, inputs, masks, disable_dropout, rngs=rngs) one_hot = jax.nn.one_hot(labels.astype(jnp.int32), 2) loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=one_hot)) return loss, logits @@ -124,7 +132,7 @@ def loss_fn(var_collect, disable_dropout=False): return loss, accuracy -def eval_model(state, test_ds, batch_size, var_collect): +def eval_model(state, test_ds, batch_size, var_collect, rngs): """Evaluation loop.""" test_ds_size = len(test_ds["sentence"]) num_steps = test_ds_size // batch_size @@ -133,11 +141,15 @@ def eval_model(state, test_ds, batch_size, var_collect): all_accuracy = [] for batch_start in range(0, valid_size, batch_size): + # Split and reassign to 'rngs' to ensure unique rng for each step + rngs = {key: jax.random.split(rngs[key])[1] for key in rngs} batch_end = batch_start + batch_size batch_inputs = test_ds["sentence"][batch_start:batch_end] batch_masks = test_ds["mask"][batch_start:batch_end] batch_labels = test_ds["label"][batch_start:batch_end] - loss, accuracy = eval_step(state, batch_inputs, batch_masks, batch_labels, var_collect) + loss, accuracy = eval_step( + state, batch_inputs, batch_masks, batch_labels, var_collect, rngs + ) all_loss.append(loss) all_accuracy.append(accuracy) @@ -197,7 +209,7 @@ def get_datasets(max_seq_len): def check_fp8(state, var_collect, inputs, masks, labels): "Check if model includes FP8." 
- rngs = {DROPOUT_KEY: jax.random.PRNGKey(0)} + rngs = {DROPOUT_KEY: jax.random.PRNGKey(0), SR_KEY: jax.random.PRNGKey(0)} func_jaxpr = str(jax.make_jaxpr(train_step)(state, inputs, masks, labels, var_collect, rngs)) assert "f8_e5m2" in func_jaxpr or "f8_e4m3" in func_jaxpr @@ -210,19 +222,20 @@ def train_and_evaluate(args): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + init_rngs = {PARAMS_KEY: params_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} input_shape = [args.batch_size, args.max_seq_len] mask_shape = [args.batch_size, 1, args.max_seq_len, args.max_seq_len] label_shape = [args.batch_size] if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None - with te.fp8_autocast( - enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + with te.autocast( + enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() ): encoder = Net(num_embed) # We use nn.Embed, thus inputs need to be in int @@ -240,21 +253,25 @@ def train_and_evaluate(args): if args.dry_run: labels = jnp.zeros(label_shape, dtype=jnp.bfloat16) - rngs = {DROPOUT_KEY: dropout_rng} + rngs = {DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} train_step(state, inputs, masks, labels, var_collect, rngs) print("PASSED") return None for epoch in range(1, args.epochs + 1): + # Split and reassign to 'rng' to ensure unique rng for each step rng, input_rng = jax.random.split(rng) rng, dropout_rng = jax.random.split(rng) - rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng} + rng, sr_rng = jax.random.split(rng) + rngs = {INPUT_KEY: input_rng, DROPOUT_KEY: dropout_rng, SR_KEY: sr_rng} state, train_loss, train_accuracy, var_collect = train_epoch( state, train_ds, 
args.batch_size, rngs, var_collect ) - test_loss, test_accuracy = eval_model(state, test_ds, args.test_batch_size, var_collect) + test_loss, test_accuracy = eval_model( + state, test_ds, args.test_batch_size, var_collect, rngs + ) print( f"Epoch: {epoch:>2} " @@ -331,8 +348,9 @@ def encoder_parser(args): class TestEncoder(unittest.TestCase): """Encoder unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) + is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) def setUp(self): """Run 3 epochs for testing""" @@ -342,7 +360,7 @@ def setUp(self): def test_te_bf16(self): """Test Transformer Engine with BF16""" actual = train_and_evaluate(self.args) - assert actual[0] < 0.45 and actual[1] > 0.79 + assert actual[0] < 0.452 and actual[1] > 0.788 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_delayed_scaling_fp8(self): @@ -350,7 +368,7 @@ def test_te_delayed_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "DelayedScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.457 and actual[1] > 0.784 @unittest.skipIf(not is_fp8_supported, fp8_reason) def test_te_current_scaling_fp8(self): @@ -358,7 +376,7 @@ def test_te_current_scaling_fp8(self): self.args.use_fp8 = True self.args.fp8_recipe = "Float8CurrentScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.461 and actual[1] > 0.784 @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason) def test_te_mxfp8(self): @@ -366,7 +384,15 @@ def test_te_mxfp8(self): self.args.use_fp8 = True self.args.fp8_recipe = 
"MXFP8BlockScaling" actual = train_and_evaluate(self.args) - assert actual[0] < 0.455 and actual[1] > 0.79 + assert actual[0] < 0.457 and actual[1] > 0.784 + + @unittest.skipIf(not is_nvfp4_supported, nvfp4_reason) + def test_te_nvfp4(self): + """Test Transformer Engine with NVFP4""" + self.args.use_fp8 = True + self.args.fp8_recipe = "NVFP4BlockScaling" + actual = train_and_evaluate(self.args) + assert actual[0] < 0.477 and actual[1] > 0.769 if __name__ == "__main__": diff --git a/examples/jax/mnist/README.md b/examples/jax/mnist/README.md index d67c1ac2b..f92229a25 100644 --- a/examples/jax/mnist/README.md +++ b/examples/jax/mnist/README.md @@ -6,13 +6,13 @@ This example uses MNIST training to demonstrate the Transformer Engine usage. Th 2. Define model: The `Net` class is a small CNN model for image classification. It has an option to switch between using `nn.Dense` provided by Flax and `te.DenseGeneral` provided by the Transformer Engine. This allows for easy comparison between the two libraries. -3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.fp8_autocast` context manager. If fp8_autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If fp8_autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under fp8_autocast. If not, then fp8_autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. 
If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword. +3. Build training loop: The `train_and_evaluate` is the main routine to initialize the model and start training and evaluating. For FP8 training, the key is `te.autocast` context manager. If autocast is enabled, it will cast all `te.DenseGeneral` to FP8 precision. The `var_collect` is a collection including needed information for model training, such as parameters and FP8 metadata, which is necessary for correct casting of BF16 tensors into FP8 tensors at runtime. If autocast is turned on and print var_collect, you will see FP8 metadata inside, such as `fp8_meta_collection` section. The training and evaluating with FP8 have to be done under autocast. If not, then autocast will deconstruct the FP8 metadata, and the model will fall back to higher floating point precision, such as BF16 in this example. To check if FP8 is enabled, use the `check_fp8` routine. If model initialization with FP8 works fine, the string returned by jax.make_jaxpr should include the `Float8` keyword. 4. Training process: In `apply_model`, the main difference between normal Flax usage and this example is, with FP8 training, the FP8 metadata has to be filled into the gradient function `grad_fn`. Otherwise, the Transformer Engine doesn't know how to cast the BF16 tensor into FP8 tensor at runtime correctly. The FP8 metadata doesn't belong in model parameters (`state.params`), so we need to manually combine the metadata and latest model parameters into var_collect as a frozen dictionary and fill it to the gradient function. 5. Evaluating process: The evaluating process is the same as the training process. Need to ensure FP8 metadata is inside var_collect and fill it into loss function. -6. Additional options: The `te.fp8_autocast` context manager has additional options +6. 
Additional options: The `te.autocast` context manager has additional options * FP8 Recipe: control FP8 training behavior. See the [FP8 tutorial](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html) for a detailed explanation of FP8 recipes and the supported options. ## Run ## diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 92baf4b0c..62f7954e0 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -18,11 +18,17 @@ import transformer_engine.jax as te import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode +from transformer_engine.jax.quantize import is_scaling_mode_supported, ScalingMode DIR = str(Path(__file__).resolve().parents[1]) sys.path.append(str(DIR)) -from encoder.common import is_bf16_supported, get_fp8_recipe_from_name_string +from encoder.common import ( + is_bf16_supported, + get_quantization_recipe_from_name_string, + unpack_cached_datasets_if_available, +) + +unpack_cached_datasets_if_available() IMAGE_H = 28 IMAGE_W = 28 @@ -189,12 +195,12 @@ def train_and_evaluate(args): label_shape = [args.batch_size] if args.use_fp8: - fp8_recipe = get_fp8_recipe_from_name_string(args.fp8_recipe) + fp8_recipe = get_quantization_recipe_from_name_string(args.fp8_recipe) else: fp8_recipe = None - with te.fp8_autocast( - enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + with te.autocast( + enabled=args.use_fp8, recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() ): cnn = Net(args.use_te) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) @@ -308,8 +314,8 @@ def mnist_parser(args): class TestMNIST(unittest.TestCase): """MNIST unittests""" - is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING) - is_mxfp8_supported, mxfp8_reason = 
is_fp8_available(ScalingMode.MXFP8_1D_SCALING) + is_fp8_supported, fp8_reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) + is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) @classmethod def setUpClass(cls): diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index d52e97d65..1fd40305c 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -68,7 +68,7 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("--seed", type=int, default=1234, help="RNG seed.") parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." ) parser.add_argument( "--no-comm-overlap", @@ -299,7 +299,7 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) dist_print(" |-- Forward pass", group=tp_group, debug=True) with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): + with te.autocast(enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world): y = model(x) if isinstance(y, tuple): out, *_ = y diff --git a/examples/pytorch/fsdp/README.md b/examples/pytorch/fsdp/README.md index 78c4d6b85..f9a49af8d 100644 --- a/examples/pytorch/fsdp/README.md +++ b/examples/pytorch/fsdp/README.md @@ -49,5 +49,5 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd # ... ``` -**NOTE:** This example has `fp8_autocast()` enabled by default. To run on GPUs without Fp8 support +**NOTE:** This example has `autocast()` enabled by default. To run on GPUs without Fp8 support (e.g.: A100), add the `--no-fp8` option to the commands shown above. 
diff --git a/examples/pytorch/fsdp/fsdp.py b/examples/pytorch/fsdp/fsdp.py index 622228536..789389757 100644 --- a/examples/pytorch/fsdp/fsdp.py +++ b/examples/pytorch/fsdp/fsdp.py @@ -173,7 +173,7 @@ def parse_fsdp_args(): "--no-fp8", action="store_true", default=False, - help="Disables the te.fp8_autocast() context.", + help="Disables the te.autocast() context.", ) parser.add_argument( "--no-defer-init", @@ -284,11 +284,11 @@ def train(opts): dtype=opts.dtype, device="cuda", ) - # fp8_autocast needs to be given the FSDP process group for amax reductions - with te.fp8_autocast(enabled=not opts.no_fp8, fp8_recipe=fp8_recipe, fp8_group=all_gpus): + # autocast needs to be given the FSDP process group for amax reductions + with te.autocast(enabled=not opts.no_fp8, recipe=fp8_recipe, amax_reduction_group=all_gpus): y = te_model(x) loss = y.sum() - # calculate gradient and take training step outside the fp8_autocast context + # calculate gradient and take training step outside the autocast context loss.backward() optim.step() optim.zero_grad(set_to_none=True) diff --git a/examples/pytorch/mnist/main.py b/examples/pytorch/mnist/main.py index 1fd1c90e7..347d36e7c 100644 --- a/examples/pytorch/mnist/main.py +++ b/examples/pytorch/mnist/main.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -54,12 +54,8 @@ def train(args, model, device, train_loader, optimizer, epoch, use_amp, use_fp8) for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) optimizer.zero_grad() - if use_amp: - with autocast(device_type='cuda', dtype=torch.float16): - output = model(data) - else: - with te.fp8_autocast(enabled=use_fp8): - output = model(data) + with te.autocast(enabled=use_fp8): + output = model(data) loss = F.nll_loss(output, target) loss.backward() optimizer.step() @@ -82,7 +78,7 @@ def calibrate(model, device, test_loader, fp8): with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) - with te.fp8_autocast(enabled=fp8, calibrating=True): + with te.autocast(enabled=fp8, calibrating=True): output = model(data) @@ -94,12 +90,8 @@ def test(model, device, test_loader, use_amp, use_fp8): with torch.no_grad(): for data, target in test_loader: data, target = data.to(device), target.to(device) - if use_amp: - with autocast(device_type='cuda', dtype=torch.float16): - output = model(data) - else: - with te.fp8_autocast(enabled=use_fp8): - output = model(data) + with te.autocast(enabled=use_fp8): + output = model(data) test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 97824bbdb..812ea384d 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -6,7 +6,8 @@ "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", "" : "", - "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)" + "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)", + "__nv_bfloat162":"__hip_bfloat162" } } diff --git a/pyproject.toml b/pyproject.toml index c4df4aecc..3814aabd0 100755 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -1,12 +1,11 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. [build-system] -requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax", "flax>=0.7.1"] +requires = ["setuptools>=61.0", "cmake>=3.21", "wheel", "pybind11[global]", "ninja", "pip", "torch>=2.1", "jax>=0.5.0", "flax>=0.7.1"] # Use legacy backend to import local packages in setup.py build-backend = "setuptools.build_meta:__legacy__" - diff --git a/qa/L0_jax_distributed_unittest/test.sh b/qa/L0_jax_distributed_unittest/test.sh index d9c46347f..ae45f398e 100644 --- a/qa/L0_jax_distributed_unittest/test.sh +++ b/qa/L0_jax_distributed_unittest/test.sh @@ -29,6 +29,10 @@ wait python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py" wait TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh" +wait + +TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/collective_gemm/run_test_cgemm.sh || test_fail "run_test_cgemm.sh" +wait if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index e4a3f4630..cb097d492 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail 
"test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder_without_custom_call.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index b4bf0a024..b6c42109b 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -2,35 +2,51 @@ # # See LICENSE for license information. +function error_exit() { + echo "Error: $1" + exit 1 +} +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} : ${NVTE_TEST_NVINSPECT_CONFIGS_DIR:=$TE_PATH/tests/pytorch/debug/test_configs/} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" # Config with the dummy feature which prevents nvinspect from being disabled. # Nvinspect will be disabled if no feature is active. : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} -FAIL=0 - # It is not installed as a requirement, # because it is not available on PyPI. 
pip uninstall -y nvdlfw-inspect pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git -pip install pytest==8.2.1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pip install pytest==8.2.1 || error_exit "Failed to install pytest" +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 
pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 - -exit $FAIL +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo 
"All tests passed" +exit 0 diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 394273ca4..e1ce68009 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -27,10 +27,11 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest 
--tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" @@ -41,7 +42,8 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +NVTE_FLASH_ATTN=0 NVTE_CPU_OFFLOAD_V1=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading_v1.xml $TE_PATH/tests/pytorch/test_cpu_offloading_v1.py || test_fail "test_cpu_offloading_v1.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail 
"test_attention.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 8ecc5a917..886f27747 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -2,11 +2,44 @@ # # See LICENSE for license information. -set -xe +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* -SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh +export NVTE_JAX_UNITTEST_LEVEL="L1" + +# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate. 
+export XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_dense.xml $TE_PATH/tests/jax/test_distributed_dense.py || test_fail "test_distributed_dense.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_helper.xml $TE_PATH/tests/jax/test_distributed_helper.py || test_fail "test_distributed_helper.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_layernorm.xml $TE_PATH/tests/jax/test_distributed_layernorm.py || test_fail "test_distributed_layernorm.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_mlp.xml $TE_PATH/tests/jax/test_distributed_layernorm_mlp.py || test_fail "test_distributed_layernorm_mlp.py" + +# XLA_FLAGS to WAR for test_distributed_softmax issue with NCCL +# TODO(Kshitij): remove when NCCL issue is fixed +XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_softmax.xml $TE_PATH/tests/jax/test_distributed_softmax.py || test_fail "test_distributed_softmax.py" + +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" + +# TODO(Phuong): add this test back after it is verified +# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" + +if [ $RET -ne 0 ]; then + echo "Error: some sub-tests failed: $FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 7f061d222..e698e997a 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ 
b/qa/L1_pytorch_distributed_unittest/test.sh @@ -30,6 +30,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" @@ -47,9 +48,9 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_ : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} -pytest -v -s $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" +pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_distributed.xml $TE_PATH/tests/pytorch/debug/test_distributed.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "debug test_distributed.py" # standard numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS 
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_2.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh index 1486d5097..7fce13a3d 100644 --- a/qa/L1_pytorch_onnx_unittest/test.sh +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -3,9 +3,11 @@ # See LICENSE for license information. -pip3 install onnxruntime==1.20.1 -pip3 install onnxruntime_extensions==0.13.0 +pip3 install onnxruntime +pip3 install onnxruntime_extensions : ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" -python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/qa/L2_jax_distributed_unittest/test.sh b/qa/L2_jax_distributed_unittest/test.sh index 0b7372650..de5624a59 100644 --- a/qa/L2_jax_distributed_unittest/test.sh +++ b/qa/L2_jax_distributed_unittest/test.sh @@ -8,4 +8,5 @@ set -xe : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate. 
+XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 7e9616cd0..418e824c1 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -18,10 +18,10 @@ sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); pri export FLASH_ATTN_CUDA_ARCHS=$sm_arch if [ $sm_arch -gt 90 ] then - FA_versions=(2.8.1) + FA_versions=(2.8.3) elif [ $sm_arch -eq 90 ] then - FA_versions=(2.7.3 2.8.1 3.0.0b1) + FA_versions=(2.7.3 2.8.3 3.0.0b1) fi for fa_version in "${FA_versions[@]}" diff --git a/setup.py b/setup.py index 1ae476311..eb241f5cb 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,11 @@ from importlib import metadata import os +import shutil +import subprocess import time from pathlib import Path from typing import List, Tuple -import subprocess import setuptools from setuptools.command.egg_info import egg_info @@ -27,6 +28,7 @@ cuda_version, get_frameworks, remove_dups, + min_python_version_str, ) frameworks = get_frameworks() @@ -171,9 +173,64 @@ def setup_requirements() -> Tuple[List[str], List[str]]: return [remove_dups(reqs) for reqs in [install_reqs, test_reqs]] +def git_check_submodules() -> None: + """ + Attempt to checkout git submodules automatically during setup. + + This runs successfully only if the submodules are + either in the correct or uninitialized state. + + Note to devs: With this, any updates to the submodules itself, e.g. moving to a newer + commit, must be commited before build. This also ensures that stale submodules aren't + being silently used by developers. + """ + + # Provide an option to skip these checks for development. + if bool(int(os.getenv("NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD", "0"))): + return + + # Require git executable. 
+ if shutil.which("git") is None: + return + + # Require a .gitmodules file. + if not (current_file_path / ".gitmodules").exists(): + return + + try: + submodules = subprocess.check_output( + ["git", "submodule", "status", "--recursive"], + cwd=str(current_file_path), + text=True, + ).splitlines() + + for submodule in submodules: + # '-' start is for an uninitialized submodule. + # ' ' start is for a submodule on the correct commit. + assert submodule[0] in ( + " ", + "-", + ), ( + "Submodules are initialized incorrectly. If this is intended, set the " + "environment variable `NVTE_SKIP_SUBMODULE_CHECKS_DURING_BUILD` to a " + "non-zero value to skip these checks during development. Otherwise, " + "run `git submodule update --init --recursive` to checkout the correct" + " submodule commits." + ) + + subprocess.check_call( + ["git", "submodule", "update", "--init", "--recursive"], + cwd=str(current_file_path), + ) + except subprocess.CalledProcessError: + return + + if __name__ == "__main__": __version__ = te_version() + git_check_submodules() + with open("README.rst", encoding="utf-8") as f: long_description = f.read() @@ -182,13 +239,20 @@ def setup_requirements() -> Tuple[List[str], List[str]]: assert bool( int(os.getenv("NVTE_RELEASE_BUILD", "0")) ), "NVTE_RELEASE_BUILD env must be set for metapackage build." 
- te_cuda_vers = "rocm" if rocm_build() else "cu12" + te_cuda_vers = "cu12" ext_modules = [] cmdclass = {} package_data = {} include_package_data = False - install_requires = ([f"transformer_engine_{te_cuda_vers}=={__version__}"],) + install_requires = [] extras_require = { + "core": [f"transformer_engine_cu12=={__version__}"], + "core_cu12": [f"transformer_engine_cu12=={__version__}"], + "core_cu13": [f"transformer_engine_cu13=={__version__}"], + "pytorch": [f"transformer_engine_torch=={__version__}"], + "jax": [f"transformer_engine_jax=={__version__}"], + } if not rocm_build() else { + "core": [f"transformer_engine_rocm=={__version__}"], "pytorch": [f"transformer_engine_torch=={__version__}"], "jax": [f"transformer_engine_jax=={__version__}"], } @@ -239,7 +303,7 @@ def setup_requirements() -> Tuple[List[str], List[str]]: long_description_content_type="text/x-rst", ext_modules=ext_modules, cmdclass={"egg_info": HipifyMeta, "build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, - python_requires=">=3.8", + python_requires=f">={min_python_version_str()}", classifiers=["Programming Language :: Python :: 3"], install_requires=install_requires, license_files=("LICENSE",), diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index 46bcf4242..d3b75bbbf 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -23,16 +23,18 @@ list(APPEND test_cuda_sources test_act.cu test_normalization.cu test_normalization_mxfp8.cu + test_memset.cu test_multi_cast_transpose.cu test_multi_padding.cu test_multi_unpadding.cu test_causal_softmax.cu - test_swizzle.cu test_swap_first_dims.cu ../test_common.cu) if(USE_CUDA) list(APPEND test_cuda_sources - test_cast_float8blockwise.cu) + test_cast_nvfp4_transpose.cu + test_cast_float8blockwise.cu + test_swizzle.cu) else() list(APPEND test_cuda_sources test_cublaslt_gemm.cu) @@ -70,6 +72,7 @@ else() add_executable(test_operator ${test_hip_sources}) endif() +# Find required packages find_package(OpenMP REQUIRED) if(USE_CUDA) list(APPEND test_operator_LINKER_LIBS CUDA::cudart GTest::gtest_main ${TE_LIB} CUDA::nvrtc CUDNN::cudnn) diff --git a/tests/cpp/operator/test_cast_float8blockwise.cu b/tests/cpp/operator/test_cast_float8blockwise.cu index cf68625ed..c9c98edc3 100644 --- a/tests/cpp/operator/test_cast_float8blockwise.cu +++ b/tests/cpp/operator/test_cast_float8blockwise.cu @@ -503,6 +503,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 2u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. + if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. @@ -556,6 +562,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) { q_opts.amax_epsilon = eps; q_opts.block_scaling_dim = 1u; + // On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, + // which requires using power of two scaling factors. Skip unsupported tests. 
+ if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) { + GTEST_SKIP(); + } + if (colwise && matrix_size.size() < 2) { // test_common Tensor initialization code does not // handle this case. diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index b635dc00b..a029e4f3f 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -86,6 +86,7 @@ void compute_ref(const ProcessingMethod processing_method, // Cache computations for (size_t i = i_min; i < i_max; ++i) { for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); @@ -314,19 +315,22 @@ void performTest_x1(const ProcessingMethod processing_method, #ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices; #else const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; #endif - std::vector mismatches_scales_indices; size_t mismatches_scales = 0; - compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; @@ -499,30 +503,36 @@ void performTest_x2(const ProcessingMethod processing_method, #ifdef __HIP_PLATFORM_AMD__ const double 
abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices_rowwise; + std::vector mismatches_scales_indices_colwise; #else const double abs_tolerable_mismatches_limit = 0.0; const double rel_tolerable_mismatches_limit = 0.0; #endif - std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif //#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); - std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, +#endif //#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + 
rel_tolerable_mismatches_limit); #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 52180786d..ba4144a7c 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -264,7 +264,9 @@ void performTest_x1(const size_t rows, rowwise, colwise); +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices; +#endif size_t mismatches_scales = 0; const size_t scale_diff_abs_tolerance = 0; const double abs_tolerable_mismatches_limit = 1.0; @@ -274,21 +276,25 @@ void performTest_x1(const size_t rows, ? output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { - compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { - compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales_indices, - mismatches_scales, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + 
mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } #ifdef __HIP_PLATFORM_AMD__ @@ -394,24 +400,34 @@ void performTest_x2(const size_t rows, const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices_rowwise; +#endif size_t mismatches_scales_rowwise = 0; - compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), - ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_indices_rowwise, mismatches_scales_rowwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, + unpadded_blocks_X_rowwise, scales_stride_rowwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_rowwise, +#endif + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ std::vector mismatches_scales_indices_colwise; +#endif size_t mismatches_scales_colwise = 0; - compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_indices_colwise, mismatches_scales_colwise, - scale_diff_abs_tolerance, - abs_tolerable_mismatches_limit, - rel_tolerable_mismatches_limit); + compare_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), unpadded_blocks_Y_colwise, + unpadded_blocks_X_colwise, scales_stride_colwise, +#ifdef __HIP_PLATFORM_AMD__ + mismatches_scales_indices_colwise, +#endif + mismatches_scales_colwise, + scale_diff_abs_tolerance, + 
abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); #ifdef __HIP_PLATFORM_AMD__ if (::testing::Test::HasFatalFailure()) return; @@ -478,7 +494,7 @@ class CastMXFP8_GatedActTestSuite : public ::testing::TestWithParam bool>> {}; TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { - #ifdef __HIP_PLATFORM_AMD__ +#ifdef __HIP_PLATFORM_AMD__ omp_set_num_threads(std::min(128, omp_get_max_threads())); // Using threads = # of vcpus causes occasional errors. #else // #ifdef __HIP_PLATFORM_AMD__ // Skip tests for pre-Blackwell architectures diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu new file mode 100644 index 000000000..afd7927da --- /dev/null +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -0,0 +1,736 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include +#include + +#include +#include +#include "../test_common.h" +#include "transformer_engine/transformer_engine.h" +#include + +using namespace transformer_engine; +using namespace test; + +namespace { + +enum ActivationType { + Identity, + GeLU, + SiLU, + ReLU, + QGeLU, + SReLU +}; + +double2 cvt_fp4x2_to_double2(fp4e2m1x2 fp4_pair) { + const __half2_raw raw_truncated_to_fp4e2m1_pair = + __nv_cvt_fp4x2_to_halfraw2(*reinterpret_cast<__nv_fp4x2_storage_t*>(&fp4_pair), __NV_E2M1); + + const __half2 truncated_to_fp4e2m1_pair(raw_truncated_to_fp4e2m1_pair); + const double truncated_to_fp4e2m1_x = static_cast(truncated_to_fp4e2m1_pair.x); + const double truncated_to_fp4e2m1_y = static_cast(truncated_to_fp4e2m1_pair.y); + return {truncated_to_fp4e2m1_x, truncated_to_fp4e2m1_y}; +} + +template +std::vector create_transpose(const InputType* const input, const size_t rows, size_t cols) { + std::vector input_t(cols * rows); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + const size_t idx_t = j * rows + i; + input_t[idx_t] = input[idx]; + } + } + return input_t; +} + +// Compute the global encode scale factor for a given global amax +float compute_global_encode_scaling_factor_FP4(const float global_amax) { + constexpr float fp8_max = 448.0f; // 448.0f; + constexpr float fp4_max = 6.0f; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, Numeric_Traits::maxNorm); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +// 1D Scaling: Original implementation with 1x16 blocks +template +void quantize_nvfp4_1d(float (*OP)(const float), + const InputType* const input, + 
fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + + constexpr size_t block_size_X = 16; + const size_t blocks_X = divide_round_up(cols, block_size_X); + + std::array cache_buffer; + for (size_t i = 0; i < block_size_X; ++i) { + cache_buffer[i] = 0.0f; + } + + for (size_t i = 0; i < rows; ++i) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t j_min = block_X * block_size_X; + const size_t j_max = j_min + block_size_X; + + // Find block amax + float block_amax = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx] = elt; + block_amax = std::max(block_amax, std::abs(elt)); + } + + // 2. Compute E4M3 scaling factor + // Compute per-block encoding/decoding scaling factor + const float S_dec_b = block_amax / 6.0f; + + // Scale & Store per-block decoding scaling factor + const float S_dec_b_fp8 = S_dec_b * S_enc; + + // Compute "correct" per-block encoding scaling factor + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 
0.f : S_enc / S_dec_b_fp8; + + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = static_cast(S_dec_b_fp8); + const float scale_reciprocal = S_enc_b_fp8; + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + + // const double2 truncated_pair = cvt_fp4x2_to_double2(casted_to_e2m1_pair); + } + } + } +} + +// Compute 2D mathematical scaling factors (8x8 for 128x128 input) +template +void compute_2d_mathematical_scales(float (*OP)(const float), + const InputType* const input, + const size_t rows, + const size_t cols, + const float global_amax, + std::vector>& math_scales) { + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + math_scales.resize(blocks_Y, std::vector(blocks_X)); + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Find 2D block amax over entire 16x16 region + float block_amax = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const float input_elt = 
static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + block_amax = std::max(block_amax, std::abs(elt)); + } + } + + // Compute E4M3 scaling factor for this 16x16 block + const float S_dec_b = block_amax / 6.0f; + const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + math_scales[block_Y][block_X] = S_dec_b_fp8; + } + } +} + +// 2D Scaling: NEW implementation with proper replication +template +void quantize_nvfp4_2d(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax) { + + // Step 1: Compute mathematical 8x8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax); + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Replicate scaling factors row-wise (128×8 storage) - only if scales is not nullptr + if (scales != nullptr) { + // Each of the 128 rows gets scaling factors from its corresponding 16×16 block + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + } + + // Step 3: Apply quantization using the mathematical scaling factors + std::array, block_size_Y> cache_buffer; + + for (size_t block_Y = 0; block_Y < blocks_Y; ++block_Y) { + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t i_min = block_Y * block_size_Y; + const size_t i_max = std::min(i_min + block_size_Y, rows); + const size_t j_min = block_X * 
block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + + // Get the scaling factor for this block + const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); + const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float scale_reciprocal = S_enc_b_fp8; + + // Process and cache data for this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const size_t idx = i * cols + j; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x = j - j_min; + + const float input_elt = static_cast(input[idx]); + const float act_elt = OP(input_elt); + const float elt = static_cast(static_cast(act_elt)); + cache_buffer[cache_idx_y][cache_idx_x] = elt; + } + } + + // Apply scaling to all elements in this 16x16 block + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const size_t cache_idx_y = i - i_min; + const size_t cache_idx_x1 = j - j_min; + const size_t cache_idx_x2 = std::min(cache_idx_x1 + 1, block_size_X - 1); + + const float cached_x = cache_buffer[cache_idx_y][cache_idx_x1]; + const float cached_y = ((j + 1) < j_max && cache_idx_x2 < block_size_X) ? 
+ cache_buffer[cache_idx_y][cache_idx_x2] : 0.0f; + + const float scaled_elt_x = cached_x * scale_reciprocal; + const float scaled_elt_y = cached_y * scale_reciprocal; + const float2 scaled_elt_pair = {scaled_elt_x, scaled_elt_y}; + + fp4e2m1x2 casted_to_e2m1_pair(scaled_elt_pair); + output[idx_pair] = casted_to_e2m1_pair; + } + } + } + } +} + +// Wrapper function that calls appropriate implementation based on 2D flag +template +void quantize_nvfp4(float (*OP)(const float), + const InputType* const input, + fp4e2m1x2* const output, + fp8e4m3* const scales, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const float global_amax, + const bool use_2d_quantization = false) { + if (use_2d_quantization) { + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } else { + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax); + } +} + +template +void compute_ref(float (*OP)(const float), + const InputType* input, + fp4e2m1x2* output, + fp4e2m1x2* output_t, + fp8e4m3* scales, + fp8e4m3* scales_t, + const float global_amax, + const size_t rows, + const size_t cols, + const size_t scales_stride, + const size_t scales_stride_t, + const bool use_2d_quantization = false) +{ + std::vector input_t = create_transpose(input, rows, cols); + + if (use_2d_quantization) { + // Step 1: Compute mathematical 8×8 scaling factors + std::vector> math_scales; + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales); + + constexpr size_t block_size_Y = 16; + constexpr size_t block_size_X = 16; + const size_t blocks_Y = divide_round_up(rows, block_size_Y); + const size_t blocks_X = divide_round_up(cols, block_size_X); + + // Step 2: Generate scales (128×8) by replicating row-wise + for (size_t i = 0; i < rows; ++i) { + const size_t block_Y = i / block_size_Y; + for (size_t block_X = 0; block_X < blocks_X; ++block_X) { + const size_t scale_idx = i * scales_stride + block_X; + 
scales[scale_idx] = math_scales[block_Y][block_X]; + } + } + + // Step 3: Generate scales_t (128×8) with proper transposed block mapping + for (size_t i = 0; i < cols; ++i) { // cols = 128, which becomes rows of transposed data + const size_t block_X_orig = i / block_size_X; // i was column index in original, so maps to block_X + for (size_t block_Y_new = 0; block_Y_new < blocks_Y; ++block_Y_new) { // block in transposed coordinate + const size_t scale_idx = i * scales_stride_t + block_Y_new; + scales_t[scale_idx] = math_scales[block_Y_new][block_X_orig]; + } + } + + // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d + // (This part processes the actual FP4 data using the mathematical scaling factors) + quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, global_amax); // scales already filled + quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, global_amax); // scales_t already filled + + } else { + quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_2d_quantization); + quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, global_amax, use_2d_quantization); + } +} + +void compare_nvfp4_tensors(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8) { + std::vector mismatch_messages; + size_t total_mismatches = 0; + + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? 
ref_data_pair.x : ref_data_pair.y); + + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = false; + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + total_mismatches++; + std::string msg = "Mismatch at place (" + std::to_string(idx + k) + "): " + + std::to_string(t) + " vs " + std::to_string(r) + + " (abs_diff: " + std::to_string(fabs(t - r)) + + ", rel_diff: " + std::to_string(r == 0 ? 0.0 : fabs((t - r) / r)) + ")"; + mismatch_messages.push_back(msg); + + // Optional: limit number of detailed messages to avoid overwhelming output + if (mismatch_messages.size() <= 100) { + std::cout << "Error in tensor " << name << ": " << msg << std::endl; + } + } + } + } + } + + // Always report summary - either success or failure + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total elements checked: " << (rows * cols) << std::endl; + + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatches found: " << total_mismatches << std::endl; + std::cout << "Mismatch rate: " << (100.0 * total_mismatches) / (rows * cols) << "%" << std::endl; + if (mismatch_messages.size() > 100) { + std::cout << "... 
and " << (mismatch_messages.size() - 100) << " more mismatches (showing first 100)" << std::endl; + } + std::cout << "============================" << std::endl; + + GTEST_FAIL() << "Found " << total_mismatches << " mismatches in tensor " << name; + } else { + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "All elements match within tolerance!" << std::endl; + std::cout << "Tensor " << name << " is IDENTICAL to reference" << std::endl; + std::cout << "============================" << std::endl; + } +} + +// Optional: Function to dump tensor data to files for detailed analysis +void dump_nvfp4_tensor_data(const std::string& prefix, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + std::string test_file = prefix + "_test.txt"; + std::string ref_file = prefix + "_ref.txt"; + std::string diff_file = prefix + "_diff.txt"; + + std::ofstream test_out(test_file); + std::ofstream ref_out(ref_file); + std::ofstream diff_out(diff_file); + + if (test_out.is_open() && ref_out.is_open() && diff_out.is_open()) { + for (int i = 0; i < rows; ++i) { + for (int j = 0; j < cols; j += 2) { + const int idx = i * cols + j; + double2 test_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[idx/2])); + double2 ref_data_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[idx/2])); + + for (int k = 0; k < 2; ++k) { + const double t = (k == 0 ? test_data_pair.x : test_data_pair.y); + const double r = (k == 0 ? ref_data_pair.x : ref_data_pair.y); + const int pos = idx + k; + + test_out << "pos[" << pos << "] = " << t << std::endl; + ref_out << "pos[" << pos << "] = " << r << std::endl; + diff_out << "pos[" << pos << "] test=" << t << " ref=" << r + << " abs_diff=" << fabs(t - r) + << " rel_diff=" << (r == 0 ? 
0.0 : fabs((t - r) / r)) << std::endl; + } + } + } + std::cout << "DEBUG: Dumped tensor data to files: " << test_file << ", " << ref_file << ", " << diff_file << std::endl; + } else { + std::cout << "WARNING: Could not open files for tensor data dump" << std::endl; + } +} + +void print_detailed_tensor_comparison(const std::string& name, + const fp4e2m1 *test_data, const fp4e2m1 *ref_data, + const int rows, const int cols) { + printf("\n=== DETAILED COMPARISON for %s (%d×%d = %d elements) ===\n", + name.c_str(), rows, cols, rows * cols); + + const int total_elements = rows * cols; + const int check_count = 128; + + printf("--- FIRST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = 0; i < std::min(check_count, total_elements); ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? "✓" : "✗"); + } + + if (total_elements > 2 * check_count) { + printf("\n--- LAST %d ELEMENTS ---\n", check_count); + printf("Index | Test_Value | Ref_Value | Match\n"); + printf("------|---------------|---------------|-------\n"); + for (int i = total_elements - check_count; i < total_elements; ++i) { + double2 test_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&test_data[i/2])); + double2 ref_pair = cvt_fp4x2_to_double2(*reinterpret_cast(&ref_data[i/2])); + + double t = (i % 2 == 0) ? test_pair.x : test_pair.y; + double r = (i % 2 == 0) ? ref_pair.x : ref_pair.y; + bool match = (fabs(t - r) < 1e-6); + + printf("%5d | %13.6f | %13.6f | %s\n", i, t, r, match ? 
"✓" : "✗"); + } + } + printf("==================================\n"); +} + +void compareResults_nvfp4(const Tensor &test, + const void *ref, const void *ref_t, const int rows, const int cols, + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + if (if_on_gpus) test.to_cpu(); + + const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data = reinterpret_cast(ref); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + + // Print detailed element-by-element comparison + // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); + // print_detailed_tensor_comparison("output_t", test_data_t, ref_data_t, cols, rows); + + // Optionally dump tensor data to files for detailed analysis + if (dump_data) { + dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + + compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); +} + +template +void performTest(float (*OP)(const float), + const std::vector& shape) { + using namespace test; + + DType itype = TypeInfo::dtype; + DType otype = DType::kFloat4E2M1; + + const size_t rows = first_dimension(shape); + const size_t cols = last_dimension(shape); + + // Use get_scale_tensor_dims for NVFP4 scale tensor dimensions + // Now that CheckScaleTensorShape is fixed, this should work correctly + const std::array scale_dims = get_scale_tensor_dims(rows, cols, 1, 16); + const std::array scale_dims_t = get_scale_tensor_dims(cols, rows, 1, 16); + + const size_t unpadded_blocks_Y = scale_dims[0]; + const size_t unpadded_blocks_X = scale_dims[1]; + const size_t blocks_Y = scale_dims[2]; + const size_t blocks_X = scale_dims[3]; + const size_t scales_stride = blocks_X; + + const size_t unpadded_blocks_Y_t = 
scale_dims_t[0]; + const size_t unpadded_blocks_X_t = scale_dims_t[1]; + const size_t blocks_Y_t = scale_dims_t[2]; + const size_t blocks_X_t = scale_dims_t[3]; + const size_t scales_stride_t = blocks_X_t; + + Tensor input("input", shape, itype); + Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + + std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); + std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); + std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); + std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + + fillCase(&input, InputsFillCase::uniform); + + // Find global amax + float amax = 0.0f; + const InputType* input_dptr = input.rowwise_cpu_dptr(); + for (size_t i = 0; i < rows; ++i) { + for (size_t j = 0; j < cols; ++j) { + const size_t idx = i * cols + j; + amax = fmaxf(amax, static_cast(input_dptr[idx])); + } + } + // Set 2nd stage NVFP4 scaling factor + output.set_scale(amax); + + bool use_2d_quantization = false; + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + output.scale(), + rows, + cols, + scales_stride, + scales_stride_t, + use_2d_quantization); + + QuantizationConfigWrapper quant_config; + + // Initialize stochastic rounding + Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); + rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed + rng_state.rowwise_cpu_dptr()[1] = 321; // rng_sequence + rng_state.from_cpu(); + quant_config.set_stochastic_rounding(false); + quant_config.set_rng_state(rng_state.data()); + + // Set 2D quantization based on compile-time flag + quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + + // Call appropriate function based on operation type + // Activation functions take 3 parameters (input, output, stream) + // nvte_quantize_v2 takes 4 parameters (input, output, quant_config, stream) + if (OP == &gelu) { + 
nvte_gelu(input.data(), output.data(), 0); + } else if (OP == &silu) { + nvte_silu(input.data(), output.data(), 0); + } else if (OP == &relu) { + nvte_relu(input.data(), output.data(), 0); + } else if (OP == &qgelu) { + nvte_qgelu(input.data(), output.data(), 0); + } else if (OP == &srelu) { + nvte_srelu(input.data(), output.data(), 0); + } else { + nvte_quantize_v2(input.data(), output.data(), quant_config, 0); + } + + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("DEBUG: CUDA error detected: %s\n", cudaGetErrorString(err)); + } + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + const double atol = 0.05; + const double rtol = 0.1; + + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); + + const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_ptr = ref_scales.get(); + const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); + const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + scale_mismatches_num); + + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); +} + +std::vector> tensor_dims = { + {32, 32}, + {32, 64}, + {64, 32}, + {64, 96}, + {128, 128}, + {256, 256}, + {512, 512}, + {1024, 1024}, + {2048, 2048}, + {128, 256}, + {8192, 128}, + {2048, 160}, + {8, 32, 1024}, + {16, 8, 4, 512}, + {1024, 16384}, + {4096, 13312}, +}; + +// Only the Identity activation is currently supported. 
+std::vector Activation_types = { + ActivationType::Identity +}; + +} // namespace + +class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam + , + transformer_engine::DType>> {}; + +TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { + // Skip tests for pre-Blackwell architectures + if (getDeviceComputeCapability() < blackwellComputeCapability) { + GTEST_SKIP(); + } + + using namespace transformer_engine; + using namespace test; + + const ActivationType Act_type = std::get<0>(GetParam()); + const auto tensor_dims = std::get<1>(GetParam()); + const DType input_type = std::get<2>(GetParam()); + + // Skip tests if the input tensor is 1D + if (tensor_dims.size() < 2) { + GTEST_SKIP(); + } + + // Forward activations + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + performTest(OP, tensor_dims); + ); +} + +std::string to_string(const ActivationType Act_type) { + switch (Act_type) { + case ActivationType::Identity: return "CAST_ONLY"; + case ActivationType::GeLU: return "GeLU"; + case ActivationType::SiLU: return "SiLU"; + case ActivationType::ReLU: return "ReLU"; + case ActivationType::QGeLU: return "QGeLU"; + case ActivationType::SReLU: return "SReLU"; + default: return ""; + } +} + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(Activation_types), + ::testing::ValuesIn(tensor_dims), + ::testing::Values(DType::kBFloat16)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for ( const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += 
"X" + test::typeName(std::get<2>(info.param)); + return name; + }); diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index 674e09c8e..7e4687895 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -123,8 +123,18 @@ void compute_ref_output(NormType norm_type, tmp = current * rsigma[i] * g; } + // Write output (scaled only for fp8 paths) output[i * H + j] = static_cast(tmp * scale); - current_max = fmaxf(current_max, fabsf(tmp)); + + // amax semantics: + // - fp8_out (scale != 1): amax on pre-scale compute value 'tmp' + // - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16) + if (scale != 1.f) { + current_max = fmaxf(current_max, fabsf(tmp)); + } else { + OutputType out_t_val = static_cast(tmp); + current_max = fmaxf(current_max, fabsf(static_cast(out_t_val))); + } } } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 9f926d07b..5427bc118 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -109,6 +109,10 @@ size_t DIVUP(const size_t &x, const size_t &y){ return (((x) + ((y)-1)) / (y)); } +size_t DIVUP_TO_MULTIPLE(const size_t &x, const size_t &y){ + return DIVUP(x, y) * y; +} + struct scale_inv_meta { std::vector shape; DType type; @@ -145,21 +149,71 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; - auto block_alignment = std::vector{128ul, 4ul}; - { - auto alignment = block_alignment[0]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(1)), alignment) * alignment; - alignment = block_alignment[1]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(32)), alignment) * alignment; - ret_rowwise.shape = {scale_dim_0, scale_dim_1}; + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), 
scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + + ret_rowwise.type = DType::kFloat8E8M0; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + ret_colwise.type = DType::kFloat8E8M0; + ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } - { - auto alignment = block_alignment[1]; - auto scale_dim_0 = DIVUP(DIVUP(first_dim, static_cast(32)), alignment) * alignment; - alignment = block_alignment[0]; - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(1)), alignment) * alignment; - ret_colwise.shape = {scale_dim_0, scale_dim_1}; + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + NVTE_CHECK(last_dim % 32 == 0); + NVTE_CHECK(first_dim % 32 == 0); + + scale_inv_meta ret_rowwise, ret_colwise; + + size_t scale_dim_Y = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X = DIVUP_TO_MULTIPLE(DIVUP(last_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y, scale_dim_X}; + + size_t scale_dim_Y_t = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_t = DIVUP_TO_MULTIPLE(DIVUP(first_dim, 16lu), scale_tensor_alignment_X_rowwise); + ret_colwise.shape = {scale_dim_Y_t, scale_dim_X_t}; + + ret_rowwise.type = DType::kFloat8E4M3; + ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + ret_colwise.type = DType::kFloat8E4M3; + 
ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E4M3); + + return {ret_rowwise, ret_colwise}; + } + if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + std::vector shape_vec; + for (size_t i = 0; i < shape.ndim; ++i) { + shape_vec.push_back(shape.data[i]); } + size_t first_dim = first_dimension(shape_vec); + size_t last_dim = last_dimension(shape_vec); + + scale_inv_meta ret_rowwise, ret_colwise; + + const size_t block_size_X_rowwise = 32; + size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise); + size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise); + ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise}; + + const size_t block_size_Y_colwise = 32; + size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise); + size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise); + ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise}; + ret_rowwise.type = DType::kFloat8E8M0; ret_colwise.type = DType::kFloat8E8M0; ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0); @@ -178,13 +232,13 @@ std::pair get_scales(const NVTEShape& shape, scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(last_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(last_dim, 128lu), 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(DIVUP(first_dim, static_cast(128)), 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(DIVUP(first_dim, 128lu), 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -204,13 +258,13 @@ std::pair get_scales(const NVTEShape& shape, 
scale_inv_meta ret_rowwise, ret_colwise; { - auto scale_dim_0 = DIVUP(last_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(first_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(last_dim, 128lu); + size_t scale_dim_1 = DIVUP(first_dim, 4) * 4; ret_rowwise.shape = {scale_dim_0, scale_dim_1}; } { - auto scale_dim_0 = DIVUP(first_dim, static_cast(128)); - auto scale_dim_1 = DIVUP(last_dim, 4) * 4; + size_t scale_dim_0 = DIVUP(first_dim, 128lu); + size_t scale_dim_1 = DIVUP(last_dim, 4) * 4; ret_colwise.shape = {scale_dim_0, scale_dim_1}; } ret_rowwise.type = DType::kFloat32; @@ -252,14 +306,15 @@ Tensor::Tensor(const std::string& name, NVTEShape columnwise_shape = {}; std::vector columnwise_shape_vec; - if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { + if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING + || scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) { // Transpose when tensor scaling columnwise_shape_vec.emplace_back(shape.data[shape.ndim - 1]); for (size_t i = 0; i < shape.ndim - 1; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } } else { - // Same shape for MX + // Same shape for MX and NVFP4 for (size_t i = 0; i < shape.ndim; ++i) { columnwise_shape_vec.emplace_back(shape.data[i]); } @@ -285,10 +340,13 @@ Tensor::Tensor(const std::string& name, std::fill_n(cpu_data_columnwise_.get(), total_size, 0); } } - tensor_.set_rowwise_data(dptr_rowwise, type, shape); - tensor_.set_columnwise_data(dptr_columnwise, type, columnwise_shape); - if (isFp8Type(type)) { + const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type; + const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? 
DType::kFloat4E2M1 : type; + tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape); + tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape); + + if (isFp8Type(type) || isFp4Type(type)) { if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) { (void)cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*) (void)cudaMemset(amax, 0, sizeof(float)); @@ -307,13 +365,19 @@ Tensor::Tensor(const std::string& name, } if (columnwise) { tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32, - std::vector{1}); + std::vector{1}); columnwise_scale_inv_cpu_data_ = std::make_unique(sizeof(float)); std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0); } } else { - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(normalized_shape, tensor_.scaling_mode()); + if (scaling_mode == NVTE_NVFP4_1D_SCALING) { + // Used for NVFP4 second stage scaling + cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*) + cudaMemset(scale, 0, sizeof(float)); + scale_cpu_data_ = std::make_shared(0); + tensor_.set_scale(scale, DType::kFloat32, std::vector{1}); + } + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(normalized_shape, tensor_.scaling_mode()); auto rowwise_scale_size = rowwise_scale_meta.bytes(); auto columnwise_scale_size = colwise_scale_meta.bytes(); auto scale_shape = rowwise_scale_meta.shape; @@ -348,13 +412,16 @@ void Tensor::to_cpu() const { cudaMemcpyDeviceToHost); } if (columnwise_) { + const DType colwise_type = tensor_.dtype(); + + const size_t colwise_size = bytes(s, colwise_type); (void)cudaMemcpy(cpu_data_columnwise_.get(), - tensor_.get_columnwise_data().data_ptr, - size, - cudaMemcpyDeviceToHost); + tensor_.get_columnwise_data().data_ptr, + colwise_size, + cudaMemcpyDeviceToHost); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING)) { if (tensor_.amax() != 
nullptr){ (void)cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), @@ -366,8 +433,7 @@ void Tensor::to_cpu() const { sizeof(float), cudaMemcpyDeviceToHost); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); (void)cudaMemcpy(rowwise_scale_inv_cpu_data_.get(), @@ -396,15 +462,15 @@ void Tensor::from_cpu() const { (void)cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice); } - if (isFp8Type(dtype())) { - if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { + if ((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING)) { if (tensor_.amax() != nullptr){ (void)cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } (void)cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(s, tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode()); if (rowwise_) { auto scale_size = rowwise_scale_meta.bytes(); (void)cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr, @@ -421,7 +487,7 @@ void Tensor::from_cpu() const { } void Tensor::set_scale(float scale) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { NVTE_CHECK(scale_cpu_data_); if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { *scale_cpu_data_ = scale; @@ -431,7 +497,7 @@ void Tensor::set_scale(float scale) { } void Tensor::set_scale_inv(float scale_inv) { - if (isFp8Type(dtype())) { + if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { NVTE_CHECK(rowwise_scale_inv_cpu_data_); } @@ -439,8 +505,7 @@ void Tensor::set_scale_inv(float 
scale_inv) { NVTE_CHECK(columnwise_scale_inv_cpu_data_); } - auto [rowwise_scale_meta, colwise_scale_meta] = - get_scales(tensor_.shape(), tensor_.scaling_mode()); + auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode()); if (rowwise_) { auto num_scales = product(rowwise_scale_meta.shape); if (num_scales == 1) { @@ -473,7 +538,8 @@ void Tensor::set_scale_inv(float scale_inv) { } void Tensor::shareFP8Meta(const Tensor &other) { - if (isFp8Type(dtype()) && isFp8Type(other.dtype())) { + if ((isFp8Type(dtype()) && isFp8Type(other.dtype())) + || isFp4Type(dtype()) && isFp4Type(other.dtype())) { auto new_tensor = TensorWrapper(other.tensor_.scaling_mode()); auto my_rowwise_data = tensor_.get_rowwise_data(); new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast(my_rowwise_data.dtype), @@ -686,26 +752,69 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } } -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, - size_t& mismatches_num, const size_t atol, - const double abs_tolerable_mismatches_limit, - const double rel_tolerable_mismatches_limit) +template +struct CastToType; + +template <> +struct CastToType { + using type = int; +}; + +template <> +struct CastToType { + using type = float; +}; + +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + using UpcastType = typename CastToType::type; + auto [atol_fp8e4m3, rtol_fp8e4m3] = getTolerances(DType::kFloat8E4M3); + + const size_t N 
= row_blocks * col_blocks; const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); mismatches_num = 0; +#ifndef __HIP_PLATFORM_AMD__ + std::vector mismatch_indices; +#endif //#ifndef __HIP_PLATFORM_AMD__ for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - const int test_val = static_cast(test[idx]); - const int ref_val = static_cast(ref[idx]); - const int abs_delta = std::abs(test_val - ref_val); + float t, r; + + bool assertion = false; - if (abs_delta > atol) { + if (std::is_same::value) { + t = static_cast(test[idx]); + r = static_cast(ref[idx]); + assertion = std::abs(t - r) > atol; + } else { + t = static_cast(*reinterpret_cast(&test[idx])); + r = static_cast(*reinterpret_cast(&ref[idx])); + const bool mismatch = (fabs(t - r) > atol_fp8e4m3) + && (r == 0 || fabs((t - r) / r) > rtol_fp8e4m3); + if (mismatch) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? 
mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + } + if (assertion) { mismatches_num++; mismatch_indices.push_back(idx); } @@ -713,8 +822,8 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, std::cout << "Error in " << name << std::endl; for (const int index : mismatch_indices) { std::cout << "Mismatch at (" << index << "):" - << static_cast(test[index]) << " vs " - << static_cast(ref[index]) << std::endl; + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; } GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " << tolerable_mismatches_limit << "."; @@ -723,7 +832,6 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, } } - #ifdef __HIP_PLATFORM_AMD__ void adjust_ref_for_e8m0_scale_error(const std::string &name, const std::vector &mismatch_idx, @@ -767,6 +875,27 @@ void adjust_ref_for_e8m0_scale_error(const std::string &name, } } #endif // #ifdef __HIP_PLATFORM_AMD__ +// Instantiate templates +template +void compare_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + +template +void compare_scaling_factors(const std::string &name, const fp8e4m3 *test, const fp8e4m3 *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef __HIP_PLATFORM_AMD__ + std::vector &mismatch_indices, +#endif //#ifdef __HIP_PLATFORM_AMD__ + size_t& mismatches_num, const size_t atol, + const 
double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit); + std::pair getTolerances(const DType type) { switch(type) { @@ -932,11 +1061,14 @@ bool isFp8Type(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2 || type == DType::kFloat8E8M0; } -int32_t getDeviceComputeCapability() -{ - cudaDeviceProp deviceProp; - (void)cudaGetDeviceProperties(&deviceProp, 0); - return 10 * deviceProp.major + deviceProp.minor; +bool isFp4Type(DType type) { + return type == DType::kFloat4E2M1; +} + +int32_t getDeviceComputeCapability() { + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, 0); + return 10 * deviceProp.major + deviceProp.minor; } size_t first_dimension(const std::vector &shape) { @@ -954,7 +1086,8 @@ std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols) { - const bool is_rowwise = (block_size_rows == 1) && (block_size_cols == 32); + const bool is_rowwise = (block_size_rows == 1) + && ((block_size_cols == 32) || (block_size_cols == 16)); const size_t alignment_Y = is_rowwise ? 
scale_tensor_alignment_Y_rowwise diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 3c0a387c6..56154c9d9 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -74,6 +74,8 @@ using fp8e5m2 = te_hip_fp8_e5m2; using fp8e8m0 = uint8_t; #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif template @@ -235,7 +237,9 @@ class Tensor { float scale() const { if(scale_cpu_data_) { - NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!"); + NVTE_CHECK((tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) + || (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING), + "Invalid scaling_mode!"); to_cpu(); return *scale_cpu_data_; } else { @@ -249,6 +253,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -262,6 +268,8 @@ class Tensor { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) { NVTE_CHECK(TypeInfo::dtype == DType::kFloat32, "Invalid type!"); + } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) { + NVTE_CHECK(TypeInfo::dtype == DType::kFloat8E4M3, "Invalid type!"); } else { NVTE_CHECK(TypeInfo::dtype == DType::kByte, "Invalid type!"); } @@ -316,10 +324,10 @@ constexpr uint32_t FP32_EXPONENT_BIAS = 127; constexpr uint32_t FP32_MANTISSA_BITS = 23; // [128,4] rowwise and [4,128] colwise alignment requirement -constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t 
scale_tensor_alignment_Y_rowwise = 128; -constexpr size_t scale_tensor_alignment_X_colwise = 128; +constexpr size_t scale_tensor_alignment_X_rowwise = 4; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +constexpr size_t scale_tensor_alignment_X_colwise = 128; inline size_t divide_round_up(const size_t N, const size_t M) { return (N - 1 + M) / M; @@ -480,12 +488,16 @@ void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - std::vector &mismatch_indices, size_t& mismatches_num, - const size_t scale_diff_abs_tolerance = 0, - const double abs_tolerable_mismatches_limit = 0, - const double rel_tolerable_mismatches_limit = 0); +template +void compare_scaling_factors(const std::string &name, const T *test, const T *ref, + const size_t row_blocks, const size_t col_blocks, const size_t stride, +#ifdef USE_ROCM + std::vector& mismatch_indices, +#endif //#ifdef USE_ROCM + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); #ifdef USE_ROCM void adjust_ref_for_e8m0_scale_error(const std::string &name, @@ -516,6 +528,7 @@ const std::string& caseName(InputsFillCase type); extern std::vector all_fp_types; bool isFp8Type(DType type); +bool isFp4Type(DType type); int32_t getDeviceComputeCapability(); constexpr int32_t hopperComputeCapability = 90; @@ -593,7 +606,7 @@ constexpr int32_t blackwellComputeCapability = 100; SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \ default: \ printf("dtype: %d\n", static_cast(dtype)); \ - NVTE_ERROR("Invalid type MARKED TEST."); \ + NVTE_ERROR("Invalid type."); \ } 
#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \ @@ -612,7 +625,7 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 2."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \ @@ -620,7 +633,7 @@ constexpr int32_t blackwellComputeCapability = 100; using namespace transformer_engine; \ SWITCH_FP4_HANDLE(type, __VA_ARGS__) \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 3."); \ + NVTE_ERROR("Invalid type."); \ } #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \ @@ -645,5 +658,5 @@ constexpr int32_t blackwellComputeCapability = 100; } \ break; \ default: \ - NVTE_ERROR("Invalid type MARKED TEST 4."); \ + NVTE_ERROR("Invalid type."); \ } diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu index 8355d5f96..884faa474 100644 --- a/tests/cpp_distributed/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -69,6 +69,34 @@ bool IsMulticastSupported(int device_id) { return supported; } +int GetDeviceComputeCapability(int device_id) { + int major{}; + int minor{}; + CHECK_CU(cuDeviceGetAttribute(&major, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device_id)); + CHECK_CU(cuDeviceGetAttribute(&minor, CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR, device_id)); + return major * 10 + minor; +} + +template +bool IsDTypeSupported(int /* device_id */) { + return true; +} + +template <> +bool IsDTypeSupported(int device_id) { + return GetDeviceComputeCapability(device_id) >= 89; +} + +template <> +bool IsDTypeSupported(int device_id) { + return GetDeviceComputeCapability(device_id) >= 89; +} + +template +bool AllDTypesSupported(int device_id) { + return (IsDTypeSupported(device_id) && ...); +} + template std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, size_t nsize, size_t ld) { @@ -161,6 +189,9 @@ class CommGemmFixure : 
public ::testing::TestWithParam { template void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + if (!AllDTypesSupported(rank_)) + GTEST_SKIP() << "FP8 is not supported on device " << rank_; + cudaStream_t stream{}; NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 3f3b5db84..137fa480d 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,5 +1,3 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,7 +8,6 @@ import pytest import jax -from jax._src.pjit import pjit from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource @@ -20,14 +17,6 @@ def generate_configs(): configs = [] - if is_devices_enough(2): - configs.append( - pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") - ) - configs.append( - pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2") - ) - if is_devices_enough(4): configs.append( pytest.param( @@ -35,10 +24,17 @@ def generate_configs(): (2, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp"), - id=f"n4_dp2_tp2", + id="n4_dp2_tp2", ) ) + if is_devices_enough(2): + configs.append( + pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") + ) + configs.append( + pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2"), + ) return configs @@ -158,13 +154,15 @@ def compare_ops( grad_args = tuple(range(len(inputs))) target_grad_func = jax.value_and_grad(target_func, argnums=grad_args) - target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) - target_fwd, target_grads = 
target_pjitter(*inputs, **kwargs) - target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text() + target_jitter = jax.jit( + target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings + ) + target_fwd, target_grads = target_jitter(*inputs, **kwargs) + target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text() ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args) - ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) - ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs) + ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings) + ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs) assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype) diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh index fcb066de7..d430e0f41 100644 --- a/tests/jax/multi_process_launch.sh +++ b/tests/jax/multi_process_launch.sh @@ -18,6 +18,14 @@ do CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 & done -CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS | tee stdout_multi_process.txt wait + +RET=0 +if grep -q "FAILED" stdout_multi_process.txt; then + RET=1 +fi + +rm -f stdout_multi_process.txt +exit "$RET" diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 6ec3c27a4..9303d6da8 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -43,6 +43,8 @@ QuantizerFactory, QuantizeLayout, noop_quantizer_set, + QuantizeMetaSet, + QuantizeMeta, ) from transformer_engine.jax.quantize import helper from transformer_engine.jax.activation import activation @@ -64,16 +66,23 @@ FP8_COMPUTE_TYPE = [jnp_float8_e4m3_type, jnp_float8_e5m2_type] LN_CASES = [(256, 128), (128, 256)] DTYPES = [jnp.bfloat16, jnp.float32] -is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available() -is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING) -supported_scaling_modes = [] +# TODO(Phuong): remove unneccessary pytest skips +is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.DELAYED_TENSOR_SCALING +) +is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.MXFP8_1D_SCALING +) +is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported( + ScalingMode.NVFP4_1D_SCALING +) + """ Find supported scaling modes""" -if is_fp8_supported: - supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING) - supported_scaling_modes.append(ScalingMode.CURRENT_TENSOR_SCALING) -if is_mxfp8_supported: - supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING) +supported_scaling_modes = helper.get_supported_scaling_modes() +non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling] +supported_recipes = helper.get_supported_quantization_recipes() +supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes] def is_shape_supported_by_mxfp8(input_shape): @@ -91,12 +100,13 @@ def assert_bitwise_scaled_tensors( a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True ): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): - if not precise_comparison: + if not precise_comparison and not a.scaling_mode.is_nvfp4_scaling: assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype) return assert 
a.scaling_mode == b.scaling_mode assert a.scale_inv.dtype == b.scale_inv.dtype + assert a.data_layout == b.data_layout if a.scaling_mode.is_tensor_scaling(): # Assert in dq_dtype as some unfused codepaths have an intermediate cast # to an input dtype which reduces precision compared to everything in fp32 @@ -104,6 +114,16 @@ def assert_bitwise_scaled_tensors( elif a.scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Compare MXFP8 scales as uint8 assert_allclose(a.scale_inv.astype(jnp.uint8), b.scale_inv.astype(jnp.uint8)) + elif a.scaling_mode.is_nvfp4_scaling: + assert_allclose(a.amax, b.amax) + assert_allclose(a.scale_inv, b.scale_inv) + if not precise_comparison: + mismatch = a.data != b.data + mismatch_fraction = jnp.mean(mismatch.astype(jnp.float32)) + assert ( + mismatch_fraction < 0.05 + ), f"Mismatch fraction {mismatch_fraction} is too high" + return else: raise ValueError(f"Unsupported scaling mode {a.scaling_mode}") assert_allclose(a.data, b.data) @@ -178,6 +198,7 @@ def assert_dequantized_grouped_scaled_tensor( ("quick_gelu", "linear"), ("squared_relu",), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] ACTIVATION_TYPES = { @@ -190,17 +211,21 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: - def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type).data + def ref_act(self, x, activation_type, act_params): + return _jax_act_lu(x, activation_type, act_params=act_params).data - def value_n_grad_ref_func(self, x, activation_type): + def value_n_grad_ref_func(self, x, activation_type, act_params): jitted_reference = jit( - value_and_grad(lambda out: jnp.mean(self.ref_act(out, activation_type)), (0,)) + value_and_grad( + lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,) + ) ) return jitted_reference(x) - def primitive_func(self, inputs, activation_type, quantizer): - out = activation(inputs, activation_type=activation_type, quantizer=quantizer) + def primitive_func(self, 
inputs, activation_type, quantizer, act_params): + out = activation( + inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params + ) return jnp.mean(out) @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES) @@ -217,12 +242,20 @@ def test_act_grad(self, shape, activation_type): x = jnp.repeat(x, len(activation_type), axis=-2) value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3) ) - - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=x.dtype) assert_allclose(prim_grad, ref_grad, dtype=x.dtype) @@ -242,7 +275,8 @@ def test_act_grad_with_tensor_scaling_fp8( self.activation_type = activation_type value_n_grad_primitive_func = jit( - value_and_grad(self.primitive_func, (0,)), static_argnums=(1,) + value_and_grad(self.primitive_func, (0,)), + static_argnums=(1, 3), ) quantizer = QuantizerFactory.create( @@ -250,9 +284,21 @@ def test_act_grad_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=QuantizeLayout.ROWWISE, ) + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) - prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, quantizer) - ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type) + act_params = ( + 
tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + prim_out, (prim_grad,) = value_n_grad_primitive_func( + x, activation_type, quantizer, act_params + ) + ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params) assert_allclose(prim_out, ref_out, dtype=output_type) assert_allclose(prim_grad, ref_grad, dtype=output_type) @@ -281,10 +327,18 @@ def test_act_forward_with_tensor_scaling_fp8( q_dtype=output_type, q_layout=q_layout, ) - - te_output = tex.act_lu(x, activation_type, te_quantizer) - jax_output = _jax_act_lu(x, activation_type, jax_quantizer) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + te_output = tex.act_lu(x, activation_type, te_quantizer, act_params) + jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params) assert_bitwise_scaled_tensors(te_output, jax_output) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) @@ -304,10 +358,18 @@ def test_act_forward_with_block_scaling_fp8( quantizer = QuantizerFactory.create( scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout ) - - output = tex.act_lu(x, activation_type, quantizer) - ref_out = self.ref_act(x, activation_type) - + act_args = ( + {"limit": 0.75, "alpha": 1.702} + if activation_type == ("clamped_silu", "clamped_linear") + else {} + ) + act_params = ( + tex.activation.ActivationParams.create(activation_type=activation_type, **act_args) + if activation_type == ("clamped_silu", "clamped_linear") + else None + ) + output = tex.act_lu(x, activation_type, quantizer, act_params) + ref_out = self.ref_act(x, activation_type, act_params) 
assert_dequantized_scaled_tensor(output, ref_out) @@ -552,7 +614,12 @@ def test_norm_forward_with_tensor_scaling_fp8( ) @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason) - @pytest.mark.parametrize("out_dtype", FP8_COMPUTE_TYPE) + @pytest.mark.parametrize( + "out_dtype", + [ + jnp_float8_e4m3_type if is_hip_extension() else jnp.float8_e4m3fn, + ], + ) def test_norm_forward_with_block_scaling_fp8( self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype ): @@ -569,10 +636,24 @@ def test_norm_forward_with_block_scaling_fp8( ) -QUANTIZE_OUTPUT_DTYPES = { +QUANTIZE_OUTPUT_FP8_DTYPES = { "L0": [jnp_float8_e4m3_type], "L2": FP8_COMPUTE_TYPE, } +QUANTIZE_OUTPUT_DTYPES = { + test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn] + for test_level in QUANTIZE_OUTPUT_FP8_DTYPES +} +QUANTIZE_QDTYPE_AND_SCALING_MODES = { + test_level: [ + (q_dtype, scaling_mode) + for q_dtype, scaling_mode in zip( + QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes + ) + if q_dtype in scaling_mode.get_compatible_q_dtypes() + ] + for test_level in QUANTIZE_OUTPUT_FP8_DTYPES +} ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [ ((32, 64), -1), @@ -581,8 +662,7 @@ def test_norm_forward_with_block_scaling_fp8( ((32, 256, 128), -1), ((32, 256, 128), -2), ((64, 32, 32, 256), -1), - ((64, 32, 32, 256), -2), - ((64, 32, 32, 256), -3), + ((8192, 2, 4096), -2), ] QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = { @@ -606,14 +686,27 @@ def test_norm_forward_with_block_scaling_fp8( @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper( - "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] + "q_layout", + [ + QuantizeLayout.ROWWISE, + QuantizeLayout.COLWISE, + QuantizeLayout.ROWWISE_COLWISE, + ], ) class TestQuantize: """ Purely quantization related tests that 
will always test on a wider set of types and shapes """ + def _skip_unsupported_dtypes(self, q_dtype, scaling_mode): + """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes.""" + if q_dtype not in scaling_mode.get_compatible_q_dtypes(): + pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}") + return + def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + self._skip_unsupported_dtypes(q_dtype, scaling_mode) + key = jax.random.PRNGKey(0) # Quantizer is created once as some quantization approaches use state from previous iterations (e.g. delayed scaling) @@ -623,6 +716,68 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt q_layout=q_layout, ) + if scaling_mode.is_nvfp4_scaling: + if in_dtype != jnp.bfloat16: + pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently") + return + q_func = _jax_quantize + # For NVFP4 scaling, the maximum possible error for a single value can be high between the dequantized and original tensors. To ensure quantization and dequantization is operating correctly without requiring a very high tolerance for all values, we instead test that quantizing the dequantized tensor is bitwise identical to the original quantized tensor. + x = jax.random.uniform(key, input_shape, in_dtype) * 10 + q1 = q_func(x, quantizer=quantizer, flatten_axis=flatten_axis) + + dq_rowwise = None + dq_colwise = None + if isinstance(q1, ScaledTensor1x): + dq = q1.dequantize() + if q1.is_colwise: + dq_colwise = dq + else: + dq_rowwise = dq + elif isinstance(q1, ScaledTensor2x): + dq_rowwise = q1.rowwise_tensor.dequantize() + dq_colwise = q1.colwise_tensor.dequantize() + else: + raise ValueError(f"Unsupported output type {type(q1)}") + + # We only compare Q-DQ for the same quantization layout. 
If we for example QDQ rowwise, then re-quantize colwise, the error will be larger and may not be bitwise identical to the original colwise quantization. + if dq_rowwise is not None: + assert ( + dq_rowwise.shape == x.shape + ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}" + q2_rowwise = q_func(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis) + q2_rowwise = ( + q2_rowwise + if isinstance(q2_rowwise, ScaledTensor1x) + else q2_rowwise.rowwise_tensor + ) + q1_rowwise = q1 if isinstance(q1, ScaledTensor1x) else q1.rowwise_tensor + assert_bitwise_scaled_tensors(q1_rowwise, q2_rowwise) + + if dq_colwise is not None: + # Since this is for NVFP4, we are assuming colwise has T layout and we do a transpose here to get back to original shape + flatten_axis = flatten_axis + len(input_shape) if flatten_axis < 0 else flatten_axis + colwise_flatten_axis = len(input_shape) - flatten_axis + dq_colwise = jnp.transpose( + dq_colwise, + (*range(colwise_flatten_axis, dq_colwise.ndim), *range(colwise_flatten_axis)), + ) + assert ( + dq_colwise.shape == x.shape + ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}" + q2_colwise = q_func(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis) + q2_colwise = ( + q2_colwise + if isinstance(q2_colwise, ScaledTensor1x) + else q2_colwise.colwise_tensor + ) + q1_colwise = q1 if isinstance(q1, ScaledTensor1x) else q1.colwise_tensor + assert_bitwise_scaled_tensors(q1_colwise, q2_colwise) + + assert ( + dq_rowwise is not None or dq_colwise is not None + ), "At least one of rowwise or colwise dq must be not None" + return + n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 for _ in range(n_iterations): x = jax.random.uniform(key, input_shape, in_dtype) @@ -630,9 +785,19 @@ def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatt scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis) assert_dequantized_scaled_tensor(scaled_tensor, x) + def 
_should_use_precise_comparison( + self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis + ): + if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16: + # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation + return False + + return True + def test_quantize_bitwise( self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis ): + self._skip_unsupported_dtypes(q_dtype, scaling_mode) key = jax.random.PRNGKey(0) input = jax.random.uniform(key, input_shape, in_dtype) @@ -644,14 +809,270 @@ def test_quantize_bitwise( jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + + assert_bitwise_scaled_tensors( + te_output, + jax_output, + precise_comparison=self._should_use_precise_comparison( + in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis + ), + ) + + def test_quantize_bitwise_jitted( + self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis + ): + self._skip_unsupported_dtypes(q_dtype, scaling_mode) + + key = jax.random.PRNGKey(0) + input = jax.random.uniform(key, input_shape, in_dtype) + + te_quantizer, jax_quantizer = QuantizerFactory.create( + n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout + ) + + jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3)) + te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,)) + + jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis) + + te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis) + + assert_bitwise_scaled_tensors( + te_output, + jax_output, + precise_comparison=self._should_use_precise_comparison( + in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis + ), + ) + + +@pytest_parametrize_wrapper("in_dtype", 
[jnp.bfloat16]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn]) +@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) +@pytest_parametrize_wrapper( + "scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling] +) +@pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] +) +class TestStochasticRounding: + + def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]: + """Dequantizes a ScaledTensor back to it's original jnp.ndarray form. This always returns an array of jnp.ndarrays, for ScaledTensor2x there will be two tensors, for ScaledTensor1x there will be one tensor.""" + if isinstance(scaled_tensor, ScaledTensor1x): + dq = scaled_tensor.dequantize() + if scaled_tensor.data_layout == "T": + dq = jnp.transpose( + dq, + ( + *range(scaled_tensor.flatten_axis, dq.ndim), + *range(scaled_tensor.flatten_axis), + ), + ) + return [dq] + elif isinstance(scaled_tensor, ScaledTensor2x): + [rowwise_dq] = self._dequantize(scaled_tensor.rowwise_tensor) + [colwise_dq] = self._dequantize(scaled_tensor.colwise_tensor) + return [rowwise_dq, colwise_dq] + raise ValueError( + "Unsupported ScaledTensor type, expected ScaledTensor but received" + f" {type(scaled_tensor)}" + ) + + def _sample_sr_qdq( + self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> list[jnp.ndarray]: + """Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors.""" + dq_tensors = [] + + key = jax.random.PRNGKey(0) + + for i in range(num_samples): + iter_key = jax.random.fold_in(key, i) + sr_rng_state = jax.random.randint( + iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32 + ) + quantizer = QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + stochastic_rounding_rng_state=sr_rng_state, + ) + + q_output = q_func(inputs, 
quantizer=quantizer, flatten_axis=flatten_axis) + iter_dq = self._dequantize(q_output) + dq_tensors.extend(iter_dq) + + avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0) + assert avg_sr_tensor.shape == inputs.shape, ( + f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" + f" {inputs.shape}" + ) + + sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) + + dq_var = jnp.var(jnp.stack(dq_tensors)) + assert ( + dq_var > 0 + ), "Variance of dequantized tensors is zero, stochastic rounding may not be working" + + return dq_tensors + + def _round_nearest( + self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> jnp.ndarray: + """Quantizes and dequantizes the input tensor with round nearest quantization.""" + quantizer = QuantizerFactory.create( + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + stochastic_rounding_rng_state=None, + ) + q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis) + rn_dq = self._dequantize(q_output)[0] + return rn_dq + + def _test_sr( + self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) -> float: + """Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples.""" + dq_tensors = self._sample_sr_qdq( + num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0) + assert avg_sr_tensor.shape == inputs.shape, ( + f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape" + f" {inputs.shape}" + ) + + round_nearest_tensor = self._round_nearest( + q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + + sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs)) + rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs)) + + assert sr_mae < rn_mae, ( + f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than" + f" round nearest 
({rn_mae})" + ) + + return sr_mae + + def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis): + """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other.""" + + key = jax.random.PRNGKey(0) + inputs = jax.random.uniform(key, input_shape, in_dtype) + + NUM_SAMPLES = 10 + + te_mean_error = self._test_sr( + NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + jax_mean_error = self._test_sr( + NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis + ) + + assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4) + + +@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16]) +@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn]) +@pytest_parametrize_wrapper( + "scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING] +) +class TestRandomizedHadamardTransform: + + @pytest_parametrize_wrapper( + "q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE] + ) + @pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)]) + def test_rht_quantize_bitwise_jitted( + self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis + ): + key = jax.random.PRNGKey(0) + inputs = jax.random.uniform(key, input_shape, in_dtype) + + te_quantizer, jax_quantizer = QuantizerFactory.create( + n_quantizers=2, + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + use_rht=True, + ) + + jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3)) + te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,)) + + jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis) + + te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis) + 
assert_bitwise_scaled_tensors(te_output, jax_output) + def _ref_gemm_with_jnp_dot(self, a, b, data_layout): + if data_layout[0] == "T": + a = jnp.swapaxes(a, -1, -2) + if data_layout[1] == "T": + b = jnp.swapaxes(b, -1, -2) + return jnp.dot(a, b) + + def _generate_gemm_input(self, m, n, k, data_layout): + key = jax.random.PRNGKey(0) + subkeys = jax.random.split(key, 2) + x = jax.random.uniform( + subkeys[0], + (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m), + dtype=jnp.bfloat16, + ) / jnp.sqrt(k) + w = jax.random.uniform( + subkeys[1], + (k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k), + dtype=jnp.bfloat16, + ) / jnp.sqrt(n) + lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,) + rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,) + contracting_dims = (lhs_contracting_dim, rhs_contracting_dim) + + return (x, w, contracting_dims) + + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + # We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently + @pytest_parametrize_wrapper("data_layout", ["TN", "NT"]) + @pytest_parametrize_wrapper("with_jax_gemm", [True, False]) + def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm): + key = jax.random.PRNGKey(0) + + lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + lhs_quantizer = QuantizerFactory.create( + scaling_mode=lhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + use_rht=True, + ) + rhs_quantizer = QuantizerFactory.create( + scaling_mode=rhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + use_rht=True, + ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + ) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + 
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn) + @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE) @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)]) @pytest_parametrize_wrapper("q_dtype", [jnp_float8_e4m3_type]) -@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) +@pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("flatten_axis", [-1]) @pytest_parametrize_wrapper("with_group_sizes", [True, False]) @pytest_parametrize_wrapper( @@ -690,7 +1111,6 @@ def test_grouped_qdq( q_layout=q_layout, n_groups=n_groups, ) - scaled_tensor = tex.grouped_quantize( x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer ) @@ -702,9 +1122,8 @@ def test_grouped_qdq( class TestFusedQuantize: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype,scaling_mode", QUANTIZE_QDTYPE_AND_SCALING_MODES) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] ) @@ -742,6 +1161,7 @@ def test_quantize_dbias( def _test_quantize_dact_dbias( self, in_dtype, input_shape, out_dtype, scaling_mode, activation_type, is_dbias, q_layout ): + key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 2) x = jax.random.uniform(subkeys[0], input_shape, in_dtype, -1, 1) @@ -788,9 +1208,15 @@ def _test_quantize_dact_dbias( assert_allclose(te_output.data, jax_output.data) if is_dbias: - # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. 
precise_comparison = not ( - in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. + (in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling()) + # Due to the amax dependency, current scaling is unfused. In TE we store the activation results in bf16 which reduces precision compared to JAX implementation which will implicitly promote to float32 for the intermediate results when JIT'd. This only produces a tolerance issue when using squared_relu currently. + or ( + activation_type in {("squared_relu",), ("clamped_silu", "clamped_linear")} + and in_dtype == jnp.bfloat16 + and scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + ) ) assert_allclose( te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype @@ -819,7 +1245,7 @@ def test_quantize_dact_dbias_no_quantization( @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE] @@ -845,7 +1271,7 @@ def test_quantize_dact_dbias_tensor_scaling( @pytest_parametrize_wrapper( "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)] ) - @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES) + @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_FP8_DTYPES) @pytest_parametrize_wrapper("is_dbias", [True, False]) @pytest_parametrize_wrapper( "q_layout", [QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE] @@ -878,6 +1304,11 @@ def 
test_quantize_dact_dbias_mxfp8_scaling( (jnp_float8_e4m3_type, jnp_float8_e5m2_type), ] +supported_nvfp4_scaling_mode_pairs = [ + (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_1D_SCALING), + (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING), +] + class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): @@ -919,7 +1350,7 @@ def test_gemm_bf16(self, m, n, k, data_layout): @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) @pytest_parametrize_wrapper("x_qtype,w_qtype", valid_fp8_gemm_operand_types) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, with_jax_gemm): @@ -953,6 +1384,40 @@ def test_gemm_fp8(self, m, n, k, x_qtype, w_qtype, scaling_mode, data_layout, wi assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) + # TODO(Phuong): add bitwise test + @pytest.mark.skipif(not is_fp4_supported, reason=fp4_unsupported_reason) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) + @pytest_parametrize_wrapper("scaling_mode_pair", supported_nvfp4_scaling_mode_pairs) + @pytest_parametrize_wrapper("data_layout", ["TN", "NT", "NN", "TT"]) + @pytest_parametrize_wrapper("with_jax_gemm", [True, False]) + def test_gemm_nvfp4(self, m, n, k, scaling_mode_pair, data_layout, with_jax_gemm): + x_uses_rht = scaling_mode_pair[0] == ScalingMode.NVFP4_1D_SCALING and data_layout[0] == "T" + w_uses_rht = scaling_mode_pair[1] == ScalingMode.NVFP4_1D_SCALING and data_layout[1] == "N" + if x_uses_rht != w_uses_rht: + # TODO(jberchtold): Ideally avoid a skip here and rewrite test setup to ensure both or neither use RHT + pytest.skip("RHT must be used for both or neither 
operand, skipping") + + lhs_scaling_mode, rhs_scaling_mode = scaling_mode_pair + x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) + lhs_quantizer = QuantizerFactory.create( + scaling_mode=lhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + ) + rhs_quantizer = QuantizerFactory.create( + scaling_mode=rhs_scaling_mode, + q_dtype=jnp.float4_e2m1fn, + ) + with use_jax_gemm(enabled=with_jax_gemm): + primitive_out = tex.gemm( + x, + w, + contracting_dims=contracting_dims, + lhs_quantizer=lhs_quantizer, + rhs_quantizer=rhs_quantizer, + ) + ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout) + assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn) + @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) def test_dense_grad_bf16(self, m, n, k): data_layout = "NN" @@ -978,11 +1443,10 @@ def ref_func(x, w, data_layout): assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16) assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16) - @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)]) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("m,n,k", [(64, 128, 128)]) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_dense_grad_fp8(self, m, n, k, scaling_mode, with_jax_gemm): + def test_dense_grad_fp8_and_fp4(self, m, n, k, recipe, with_jax_gemm): data_layout = "NN" x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout) @@ -1004,13 +1468,13 @@ def ref_func(x, w, bias, data_layout): value_n_grad_ref_func = value_and_grad(ref_func, (0, 1, 2)) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( 
+ x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), ) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): primitive_out, (primitive_x_grad, primitive_w_grad, primitive_bias_grad) = ( @@ -1021,10 +1485,10 @@ def ref_func(x, w, bias, data_layout): x, w, bias, data_layout ) - assert_allclose(primitive_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(primitive_out, ref_out, dtype=quantizer_set.x.q_dtype) + assert_allclose(primitive_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(primitive_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(primitive_bias_grad, ref_bias_grad, dtype=quantizer_set.dgrad.q_dtype) @pytest.fixture(name="random_inputs") @@ -1049,11 +1513,11 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan class TestFusedDense: @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) - @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) + @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) + @pytest_parametrize_wrapper("recipe", supported_recipes) @pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) - def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_gemm): + def test_layernorm_dense_grad(self, m, n, k, recipe, norm_type, with_jax_gemm): """ Test layernorm_dense VJP Rule """ @@ -1071,10 +1535,10 @@ def test_layernorm_dense_grad(self, m, n, k, scaling_mode, norm_type, with_jax_g gamma = 
jax.random.normal(subkeys[2], (k,)).astype(jnp.bfloat16) quantizer_set = QuantizerFactory.create_set( - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), ) if norm_type == "layernorm": @@ -1110,7 +1574,7 @@ def ref_func(x, w, gamma, beta): x, w, gamma, beta ) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): prim_out, ( @@ -1120,22 +1584,22 @@ def ref_func(x, w, gamma, beta): prim_beta_grad, ) = value_n_grad_prim_func(x, w, gamma, beta) - assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_w_grad, ref_w_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(prim_out, ref_out, dtype=quantizer_set.x.q_dtype) + assert_allclose(prim_x_grad, ref_x_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(prim_w_grad, ref_w_grad, dtype=quantizer_set.dgrad.q_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=quantizer_set.dgrad.q_dtype) if beta is not None: - assert_allclose(prim_beta_grad, ref_beta_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(prim_beta_grad, ref_beta_grad, dtype=quantizer_set.dgrad.q_dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) - @pytest.mark.parametrize("m,n,k", [(64, 32, 64)]) + @pytest.mark.parametrize("m,n,k", [(64, 128, 128)]) @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")]) - @pytest.mark.parametrize("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("recipe", supported_recipes) 
@pytest.mark.parametrize("norm_type", ["layernorm", "rmsnorm"]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( - self, m, n, k, activation_type, scaling_mode, norm_type, use_bias, with_jax_gemm + self, m, n, k, activation_type, recipe, norm_type, use_bias, with_jax_gemm ): """ Test layernorm_mlp VJP Rule @@ -1163,10 +1627,10 @@ def test_layernorm_mlp_grad( quantizer_sets = QuantizerFactory.create_set( n_quantizer_sets=2, - scaling_mode=scaling_mode, - fwd_dtype=jnp_float8_e4m3_type, - bwd_dtype=jnp_float8_e5m2_type if scaling_mode.is_tensor_scaling() else jnp_float8_e4m3_type, - is_2x2x=True, + fp8_recipe=recipe, + quantize_meta_set=QuantizeMetaSet( + x=QuantizeMeta(), kernel=QuantizeMeta(), grad=QuantizeMeta() + ), ) if norm_type == "layernorm": @@ -1213,7 +1677,7 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): value_n_grad_prim_func = value_and_grad(prim_func, range(6)) value_n_grad_ref_func = value_and_grad(ref_func, range(6)) - n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1 + n_iterations = 3 if recipe.delayed() else 1 with use_jax_gemm(enabled=with_jax_gemm): for _ in range(n_iterations): prim_out, ( @@ -1234,18 +1698,16 @@ def ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2): ref_bias_2_grad, ) = value_n_grad_ref_func(x, gamma, kernel_1, kernel_2, bias_1, bias_2) - assert_allclose(prim_out, ref_out, dtype=jnp_float8_e4m3_type) - - assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=jnp_float8_e5m2_type) - if use_bias: - assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=jnp_float8_e5m2_type) - - assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=jnp_float8_e5m2_type) + fwd_dtype = quantizer_sets[0].x.q_dtype + bwd_dtype = quantizer_sets[0].dgrad.q_dtype + assert_allclose(prim_out, ref_out, dtype=fwd_dtype) + assert_allclose(prim_kernel_2_grad, ref_kernel_2_grad, dtype=bwd_dtype) + 
assert_allclose(prim_kernel_1_grad, ref_kernel_1_grad, dtype=bwd_dtype) + assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=bwd_dtype) + assert_allclose(prim_x_grad, ref_x_grad, dtype=bwd_dtype) if use_bias: - assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=jnp_float8_e5m2_type) - - assert_allclose(prim_gamma_grad, ref_gamma_grad, dtype=jnp_float8_e5m2_type) - assert_allclose(prim_x_grad, ref_x_grad, dtype=jnp_float8_e5m2_type) + assert_allclose(prim_bias_2_grad, ref_bias_2_grad, dtype=bwd_dtype) + assert_allclose(prim_bias_1_grad, ref_bias_1_grad, dtype=bwd_dtype) # E5M2 * E5M2 is not supported @@ -1324,25 +1786,37 @@ def _assert_grouped_gemm_output(self, out, group_sizes, ref_list, dtype): @pytest_parametrize_wrapper("dtype", [jnp.bfloat16, jnp.float16]) @pytest_parametrize_wrapper("layout", ["NN"]) - def test_grouped_gemm_fp16(self, dtype, input_shape, layout): + @pytest_parametrize_wrapper("use_async_d2h_group_size", [True, False]) + def test_grouped_gemm_fp16(self, dtype, input_shape, layout, use_async_d2h_group_size): lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( dtype, input_shape, layout ) + if use_async_d2h_group_size: + if is_hip_extension(): + pytest.skip("ROCm does not support use_async_d2h_group_sizes yet.") + num_gemms = input_shape[0] + _ = jax.jit(tex.grouped_gemm_copy_group_sizes, static_argnames=("num_gemms",))( + group_sizes, + num_gemms=num_gemms, + ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm - prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + prim_out = jax.jit( + tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") + )( lhs, rhs, group_sizes, contracting_dims, + use_async_d2h_group_sizes=use_async_d2h_group_size, ) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) 
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) @pytest_parametrize_wrapper("layout", ["NN"]) def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout): fwd_dtype, bwd_dtype = fwd_bwd_dtype @@ -1423,7 +1897,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): "fwd_bwd_dtype", [(jnp_float8_e4m3_type, jnp_float8_e4m3_type), (jnp_float8_e4m3_type, jnp_float8_e5m2_type)], ) - @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes) + @pytest_parametrize_wrapper("scaling_mode", non_fp4_supported_scaling_modes) def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 diff --git a/tests/jax/test_distributed_dense.py b/tests/jax/test_distributed_dense.py new file mode 100644 index 000000000..15b146343 --- /dev/null +++ b/tests/jax/test_distributed_dense.py @@ -0,0 +1,253 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import unittest + +import jax +import jax.numpy as jnp +import numpy as np +from jax import random +from jax.sharding import Mesh, NamedSharding, PartitionSpec +from functools import partial + +from distributed_test_base import generate_configs +from utils import assert_allclose, pytest_parametrize_wrapper + +import transformer_engine.jax.cpp_extensions as tex +from transformer_engine.jax import autocast +from transformer_engine.jax.dense import dense + + +DTYPES = [jnp.bfloat16] + +GEMM_INPUT_SHAPES = [[256, 128, 256]] # [batch, seq_len, hidden_in] + +WEIGHT_SHAPES = [[256, 256]] # [hidden_in, hidden_out] + + +def _generate_inputs(input_shape, weight_shape, dtype): + """Generate test inputs for GEMM operations""" + _, _, hidden_in = input_shape + hidden_in_w, hidden_out = weight_shape + assert hidden_in == hidden_in_w, f"Dimension mismatch: {hidden_in} != {hidden_in_w}" + + bias_shape = (hidden_out,) + + # Generate random inputs + x = random.normal(random.PRNGKey(1124), input_shape, dtype=dtype) + weight = random.normal(random.PRNGKey(2248), weight_shape, dtype=dtype) / jnp.sqrt(hidden_in_w) + bias = random.normal(random.PRNGKey(3372), bias_shape, dtype=dtype) / jnp.sqrt(hidden_out) + + return x, weight, bias + + +def _get_sharding_for_gemm(mesh, mesh_resource, partition_layout="rowwise"): + """Get sharding patterns for GEMM inputs and outputs""" + + dp_axis = mesh_resource.dp_resource + tp_axis = mesh_resource.tpsp_resource + + if partition_layout == "colwise": + x_spec = PartitionSpec(dp_axis, None, None) + weight_spec = PartitionSpec(None, tp_axis) + bias_spec = PartitionSpec(tp_axis) + output_spec = PartitionSpec(dp_axis, None, tp_axis) + elif partition_layout == "rowwise": + x_spec = PartitionSpec(dp_axis, None, tp_axis) + weight_spec = PartitionSpec(tp_axis, None) + bias_spec = PartitionSpec(None) + output_spec = PartitionSpec(dp_axis, None, None) + else: + raise ValueError(f"Invalid partition: {partition_layout}") + + x_sharding = NamedSharding(mesh, 
x_spec) + weight_sharding = NamedSharding(mesh, weight_spec) + bias_sharding = NamedSharding(mesh, bias_spec) + output_sharding = NamedSharding(mesh, output_spec) + + return x_sharding, weight_sharding, bias_sharding, output_sharding + + +@partial(jax.jit, static_argnames=("contracting_dims", "output_sharding")) +def _jitted_gemm(x, weight, bias, contracting_dims, output_sharding): + output = tex.gemm( + x, + weight, + bias=bias, + contracting_dims=contracting_dims, + fuse_bias=True, + ) + if output_sharding is not None: + output = jax.lax.with_sharding_constraint(output, output_sharding) + return output + + +# TODO(Phuong): +# 1. Add supported recipes after FP4 is added +# 2. Add communication type/byte checks +class TestDistributedDense: + """Test distributed GEMM without collective operations vs JAX dot""" + + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_configs(), + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES) + @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES) + @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"]) + def test_distributed_gemm( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + dtype, + input_shape, + weight_shape, + partition, + ): + """Test TE GEMM against JAX dot with bf16 dtype""" + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = Mesh(devices, mesh_axes) + + # Generate inputs + x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype) + + # Get sharding patterns + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm( + mesh, mesh_resource, partition_layout=partition + ) + + # Shard inputs + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + contracting_dims = ((2,), (0,)) # Contract on hidden_in dimension + + 
with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + # TE GEMM result + te_result = _jitted_gemm( + x_sharded, + weight_sharded, + bias_sharded, + contracting_dims=contracting_dims, + output_sharding=output_sharding, + ) + + # JAX dot reference result + jax_result = ( + jax.lax.dot_general( + x_sharded, weight_sharded, dimension_numbers=(contracting_dims, ((), ())) + ) + + bias_sharded + ) + + assert te_result.sharding == jax_result.sharding + # Ensure computation is complete + jax.block_until_ready(te_result) + jax.block_until_ready(jax_result) + + # Gather results for comparison + gathered_te = jax.lax.with_sharding_constraint( + te_result, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_jax = jax.lax.with_sharding_constraint( + jax_result, NamedSharding(mesh, PartitionSpec(None)) + ) + + # Compare results + assert_allclose(gathered_te, gathered_jax, dtype=dtype) + + def _te_sum_dense(self, x, weight, bias, contracting_dims): + """TE GEMM function for gradient testing""" + return jnp.sum(dense(x, weight, bias=bias, contracting_dims=contracting_dims)) + + def _jax_sum_dense(self, x, weight, bias, contracting_dims): + """JAX dot function for gradient testing""" + result = ( + jax.lax.dot_general(x, weight, dimension_numbers=(contracting_dims, ((), ()))) + bias + ) + return jnp.sum(result) + + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_configs(), + ) + @pytest_parametrize_wrapper("dtype", DTYPES) + @pytest_parametrize_wrapper("input_shape", GEMM_INPUT_SHAPES) + @pytest_parametrize_wrapper("weight_shape", WEIGHT_SHAPES) + @pytest_parametrize_wrapper("partition", ["rowwise", "colwise"]) + def test_te_distributed_dense_grad( + self, + device_count, + mesh_shape, + mesh_axes, + mesh_resource, + dtype, + input_shape, + weight_shape, + partition, + ): + """Test TE GEMM gradients against JAX dot gradients""" + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = 
Mesh(devices, mesh_axes) + + # Generate inputs + x, weight, bias = _generate_inputs(input_shape, weight_shape, dtype) + + # Get sharding patterns + x_sharding, weight_sharding, bias_sharding, output_sharding = _get_sharding_for_gemm( + mesh, mesh_resource, partition_layout=partition + ) + + x_sharded = jax.device_put(x, x_sharding) + weight_sharded = jax.device_put(weight, weight_sharding) + bias_sharded = jax.device_put(bias, bias_sharding) + + contracting_dims = ((2,), (0,)) + + with mesh, autocast(enabled=False, mesh_resource=mesh_resource): + # Test gradients w.r.t. all inputs + te_grad_func = jax.jit( + jax.value_and_grad(self._te_sum_dense, argnums=(0, 1, 2)), + static_argnames=("contracting_dims",), + ) + jax_grad_func = jax.jit( + jax.value_and_grad(self._jax_sum_dense, argnums=(0, 1, 2)), + static_argnames=("contracting_dims",), + ) + + te_val, te_grads = te_grad_func( + x_sharded, weight_sharded, bias_sharded, contracting_dims + ) + jax_val, jax_grads = jax_grad_func( + x_sharded, weight_sharded, bias_sharded, contracting_dims + ) + + # Compare forward pass + assert_allclose(te_val, jax_val, dtype=dtype) + + # Compare gradients + for i, (te_grad, jax_grad) in enumerate(zip(te_grads, jax_grads)): + te_grad_spec = tuple(i for i in te_grad.sharding.spec if i is not None) + jax_grad_spec = tuple(i for i in jax_grad.sharding.spec if i is not None) + assert te_grad_spec == jax_grad_spec, f"Gradient sharding mismatch at te_grads[{i}]" + gathered_te_grad = jax.lax.with_sharding_constraint( + te_grad, NamedSharding(mesh, PartitionSpec(None)) + ) + gathered_jax_grad = jax.lax.with_sharding_constraint( + jax_grad, NamedSharding(mesh, PartitionSpec(None)) + ) + assert_allclose( + gathered_te_grad, + gathered_jax_grad, + dtype=dtype, + err_msg=f"Gradient mismatch for argument {i}", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/jax/test_distributed_helper.py b/tests/jax/test_distributed_helper.py index e74e9aa6f..c9647c13c 100644 --- 
a/tests/jax/test_distributed_helper.py +++ b/tests/jax/test_distributed_helper.py @@ -9,7 +9,7 @@ from utils import pytest_parametrize_wrapper, is_devices_enough from transformer_engine.jax.sharding import MeshResource, global_mesh_resource -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast def generate_mesh_configs(): @@ -26,10 +26,10 @@ def generate_mesh_configs(): class TestMeshResource(unittest.TestCase): - def test_fp8_autocast_with_mesh_resource(self): + def test_autocast_with_mesh_resource(self): for mesh_config in generate_mesh_configs(): device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = jax.sharding.Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + with mesh, autocast(enabled=False, mesh_resource=mesh_resource): self.assertEqual(mesh_resource, global_mesh_resource()) diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index 03c0d1119..4358a6111 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -16,7 +16,7 @@ from distributed_test_base import compare_ops from utils import pytest_parametrize_wrapper -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.common import recipe from transformer_engine.jax.layernorm import layernorm from transformer_engine.jax.quantize import QuantizerFactory, ScalingMode, is_fp8_available @@ -68,20 +68,19 @@ def generate_collectives_count_ref( self, mesh_resource, ln_type, shape, dtype, mesh_axes, fp8_recipe ): jax_dtype = jax.dtypes.canonicalize_dtype(dtype) + # TODO(Phuong) is_dp_enabled = dp mesh axis size > 1 is_dp_enabled = mesh_resource.dp_resource is not None + is_tpsp_enabled = mesh_resource.tpsp_resource is not None assert ln_type in ["layernorm", "rmsnorm"] - all_reduce_loss_bytes = 
4 # 1 * FP32 - # for loss, dgamma and dbeta - # TODO(Jeremy): debug this check because layernorm should always have 2x weights regardless of dp - weight_count = 2 if (ln_type == "layernorm" and "dp" in mesh_axes) else 1 - allreduce_total_bytes = ( - all_reduce_loss_bytes + weight_count * shape[-1] * jax_dtype.itemsize - ) - other_bytes = 0 - if fp8_recipe == recipe.Float8CurrentScaling(): - allreduce_total_bytes += jax_dtype.itemsize # 1 * dtype for the amax reduction + # loss, 1 FP32 + allreduce_total_bytes = 4 if is_dp_enabled else 0 + # dgamma and dbeta + weight_count = 2 if ln_type == "layernorm" else 1 + allreduce_total_bytes += weight_count * shape[-1] * jax_dtype.itemsize return generate_collectives_count( - allreduce=allreduce_total_bytes * int(is_dp_enabled), allgather=0, other=other_bytes + allreduce=allreduce_total_bytes * int(is_dp_enabled or is_tpsp_enabled), + allgather=0, + other=0, ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @@ -136,10 +135,13 @@ def ref_func(x, gamma, beta): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) - gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) - beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) + with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + x_named_sharding = NamedSharding(mesh, x_pspec) + g_named_sharding = NamedSharding(mesh, g_pspec) + b_named_sharding = NamedSharding(mesh, b_pspec) + x_ = jax.device_put(x, x_named_sharding) + gamma_ = jax.device_put(gamma, g_named_sharding) + beta_ = jax.device_put(beta, b_named_sharding) with warnings.catch_warnings(record=True) as warns: try: @@ -151,8 +153,11 @@ def ref_func(x, gamma, beta): grad_args=(0, 1, 2), metric_fwd_dtype=q_dtype, metric_bwd_dtype=q_dtype, - 
in_shardings=(x_pspec, g_pspec, b_pspec), - out_shardings=(None, (x_pspec, g_pspec, b_pspec)), + in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding), + out_shardings=( + None, + (x_named_sharding, g_named_sharding, b_named_sharding), + ), ) except AssertionError as err: # Layernorm should still produce the correct numerical result with @@ -212,9 +217,11 @@ def ref_func(x, gamma): ) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource): - x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) - gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) + with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource): + x_named_sharding = NamedSharding(mesh, x_pspec) + g_named_sharding = NamedSharding(mesh, g_pspec) + x_ = jax.device_put(x, x_named_sharding) + gamma_ = jax.device_put(gamma, g_named_sharding) with warnings.catch_warnings(record=True) as warns: try: @@ -226,8 +233,8 @@ def ref_func(x, gamma): grad_args=(0, 1), metric_fwd_dtype=q_dtype, metric_bwd_dtype=q_dtype, - in_shardings=(x_pspec, g_pspec), - out_shardings=(None, (x_pspec, g_pspec)), + in_shardings=(x_named_sharding, g_named_sharding), + out_shardings=(None, (x_named_sharding, g_named_sharding)), ) except AssertionError as err: # RmsNorm should still produce the correct numerical result with diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 0af10d050..c67528f04 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -3,6 +3,7 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
+import re from typing import Callable, Sequence, Union, Optional import pytest @@ -19,8 +20,12 @@ ) from transformer_engine.common import recipe -from transformer_engine.jax.quantize import is_fp8_available, ScalingMode -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax.quantize import ( + is_fp8_available, + ScalingMode, + get_quantize_config_with_recipe, +) +from transformer_engine.jax import autocast from transformer_engine.jax.flax import LayerNormMLP from transformer_engine.jax.layernorm_mlp import layernorm_mlp from transformer_engine.jax.sharding import ( @@ -35,7 +40,11 @@ W_JOINED_AXES, ) from transformer_engine.jax.sharding import MeshResource -from transformer_engine.jax.quantize import QuantizerFactory +from transformer_engine.jax.quantize import ( + QuantizerFactory, + get_supported_quantization_recipes, + is_scaling_mode_supported, +) from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability from transformer_engine.jax.util import ( @@ -46,18 +55,15 @@ jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() -is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) +is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) +is_mxfp8_supported, reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +is_nvfp4_supported, reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) -SUPPORTED_RECIPES = [] -if is_fp8_supported: - SUPPORTED_RECIPES.append(pytest.param(recipe.DelayedScaling(), id="DelayedScaling")) - SUPPORTED_RECIPES.append(pytest.param(recipe.Float8CurrentScaling(), id="CurrentScaling")) -if is_mxfp8_supported: - SUPPORTED_RECIPES.append(pytest.param(recipe.MXFP8BlockScaling(), id="MXFP8BlockScaling")) +SUPPORTED_RECIPES = get_supported_quantization_recipes() +SUPPORTED_RECIPES = [pytest.param(r, id=r.__class__.__name__) for r in 
SUPPORTED_RECIPES] DTYPES = [jnp.bfloat16, jnp.float16] -INPUT_SHAPE = [[4, 64, 128]] # [batch, seqlen, hidden_in] +INPUT_SHAPE = [[4, 128, 256]] # [batch, seqlen, hidden_in] LAYERNORM_INPUT_AXES = (BATCH_AXES, SEQLEN_TP_AXES, HIDDEN_AXES) DOT_1_INPUT_AXES = (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES) @@ -68,19 +74,47 @@ LN_BIAS_AXES = (W_NO_SHARD_AXES,) BIAS_1_AXES = (W_JOINED_AXES, W_TP_AXES) BIAS_2_AXES = (W_NO_SHARD_AXES,) -INTERMEDIATE = 64 +INTERMEDIATE = 256 # Only test with FSDP and TPSP as DP is not used def generate_fsdp_and_tpsp_configs(): configs = [] + if is_devices_enough(4): + configs.append( + pytest.param( + [ + 4, + (2, 2), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp2_tpsp2", + ) + ) + if is_devices_enough(2): configs.append( - [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] + pytest.param( + [ + 2, + (1, 2), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp1_tpsp2", + ) ) - if is_devices_enough(4): configs.append( - [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] + pytest.param( + [ + 2, + (2, 1), + ("fsdp", "tpsp"), + MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp"), + ], + id="fsdp2_tpsp1", + ), ) return configs @@ -122,6 +156,7 @@ def layernorm_fp8_mlp_prim_func( layernorm_type: str = "rmsnorm", activation_type: Sequence[Union[str, Callable]] = ("gelu",), multi_gpus: bool = False, + quantization_recipe: recipe.Recipe = None, ) -> jnp.ndarray: if multi_gpus: @@ -135,7 +170,9 @@ def layernorm_fp8_mlp_prim_func( dot_1_input_axes = dot_2_input_axes = None kernel_1_axes = kernel_2_axes = None - quantizer_sets = QuantizerFactory.create_set(n_quantizer_sets=2) + quantizer_sets = QuantizerFactory.create_set( + n_quantizer_sets=2, fp8_recipe=quantization_recipe + ) # out = ((x * kernel_1) + bias_1) * kernel_2 + bias_2 return jnp.mean( @@ -163,7 +200,7 @@ def 
_test_layernorm_mlp_grad( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, use_shardy, with_jax_gemm, ): @@ -182,8 +219,10 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + with autocast( + enabled=quantization_recipe is not None, + recipe=quantization_recipe, + mesh_resource=MeshResource(), ): single_jitter = jax.jit( value_and_grad_func, @@ -194,8 +233,10 @@ def _test_layernorm_mlp_grad( # Multi GPUs devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + with mesh, autocast( + enabled=quantization_recipe is not None, + recipe=quantization_recipe, + mesh_resource=mesh_resource, ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) @@ -239,12 +280,14 @@ def _test_layernorm_mlp_grad( if is_hip_extension() and (jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd))): pytest.skip("skip tests with nan/inf single fwd.") - - fwd_test_type = dtype if fp8_recipe is None else jnp_float8_e4m3_type - bwd_test_type = dtype if fp8_recipe is None else jnp_float8_e5m2_type + fwd_test_type = bwd_test_type = dtype + if quantization_recipe is not None: + quantize_config = get_quantize_config_with_recipe(quantization_recipe) + fwd_test_type = quantize_config.FWD_DTYPE + bwd_test_type = quantize_config.BWD_DTYPE if fwd_test_type == jnp.float16 and use_bias: - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) + assert_allclose(multi_fwd, single_fwd, atol=0.04, rtol=1.5) else: assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) @@ -267,13 +310,12 @@ def _test_layernorm_mlp_grad( err_msg=f"multi_grads[{i}] is not close", ) - @pytest.mark.skipif(not is_fp8_supported, 
reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -282,27 +324,28 @@ def test_layernorm_mlp_grad( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, with_jax_gemm, ): + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, use_shardy=False, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, @@ -311,18 +354,18 @@ def test_layernorm_mlp_grad_shardy( use_bias, input_shape, dtype, - fp8_recipe, + quantization_recipe, with_jax_gemm, ): - if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + if dtype == jnp.float16 
and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp_grad( mesh_config, activation_type, use_bias, input_shape, dtype, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=True, with_jax_gemm=with_jax_gemm, ) @@ -335,7 +378,7 @@ def _test_layernorm_mlp( input_shape, dtype, use_fp8, - fp8_recipe, + quantization_recipe, use_shardy, with_jax_gemm, ): @@ -344,31 +387,34 @@ def _test_layernorm_mlp( layernorm_type = "rmsnorm" rng = jax.random.PRNGKey(0) - subkeys = jax.random.split(rng, 2) + subkeys = jax.random.split(rng, 3) x = jax.random.normal(subkeys[0], (batch, seqlen, hidden_in), dtype) - init_rngs = {"params": subkeys[1]} + init_rngs = {"params": subkeys[1], "sr_rng": subkeys[2]} with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with autocast( + enabled=use_fp8, recipe=quantization_recipe, mesh_resource=MeshResource() + ): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, intermediate_dim=INTERMEDIATE, activations=activation_type, use_bias=use_bias, + return_layernorm_output=True, ) params_single = ln_mlp_single.init(init_rngs, x, deterministic=True) mlp_out_single, ln_out_single = ln_mlp_single.apply( - params_single, x, deterministic=True + params_single, x, deterministic=True, rngs={"sr_rng": subkeys[2]} ) # Multi GPUs device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast( - enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + with mesh, autocast( + enabled=use_fp8, recipe=quantization_recipe, mesh_resource=mesh_resource ): ln_mlp_sharded = LayerNormMLP( layernorm_type=layernorm_type, @@ -385,22 +431,24 @@ def _test_layernorm_mlp( dot_1_input_axes=DOT_1_INPUT_AXES, 
 dot_2_input_axes=DOT_2_INPUT_AXES, name="mlp", + return_layernorm_output=True, ) params_sharded = ln_mlp_sharded.init(init_rngs, x, deterministic=True) mlp_out_sharded, ln_out_sharded = ln_mlp_sharded.apply( - params_sharded, x, deterministic=True + params_sharded, x, deterministic=True, rngs={"sr_rng": subkeys[2]} ) # Make sure params values are the same assert_tree_like_allclose(params_sharded["params"], params_single["params"]) assert_allclose(ln_out_sharded, ln_out_single, dtype=dtype) + # TODO(Phuong): check if these tols updates are still needed atol = None rtol = None l40_tolerance_update = ( get_min_device_compute_capability() == 89 - and fp8_recipe == recipe.DelayedScaling() and use_fp8 + and quantization_recipe.delayed() and dtype == jnp.float16 and activation_type == ("gelu",) ) @@ -418,9 +466,10 @@ def _test_layernorm_mlp( # within tolerance to the float32 ground truth. jax_triton_gemm_precision_tolerance_update = ( with_jax_gemm - and isinstance(fp8_recipe, recipe.Float8CurrentScaling) - and dtype == jnp.bfloat16 - and activation_type == ("gelu", "linear") + and quantization_recipe is not None + and (quantization_recipe.delayed() or quantization_recipe.float8_current_scaling()) + and dtype in (jnp.bfloat16, jnp.float16) + and activation_type == ("gelu", "linear") ) if jax_triton_gemm_precision_tolerance_update: atol = 0.08 @@ -444,22 +493,30 @@ def test_layernorm_mlp_layer( input_shape, dtype, use_fp8=False, - fp8_recipe=None, + quantization_recipe=None, use_shardy=False, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + 
@pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + quantization_recipe, + with_jax_gemm, ): + if dtype == jnp.float16 and quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp( mesh_config, activation_type, @@ -467,7 +524,7 @@ def test_layernorm_mlp_layer_fp8( input_shape, dtype, use_fp8=True, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=False, with_jax_gemm=with_jax_gemm, ) @@ -488,24 +545,30 @@ def test_layernorm_mlp_layer_shardy( input_shape, dtype, use_fp8=False, - fp8_recipe=None, + quantization_recipe=None, use_shardy=True, with_jax_gemm=with_jax_gemm, ) - @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("dtype", DTYPES) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_layer_fp8_shardy( - self, mesh_config, activation_type, use_bias, input_shape, dtype, fp8_recipe, with_jax_gemm + self, + mesh_config, + activation_type, + use_bias, + input_shape, + dtype, + quantization_recipe, + with_jax_gemm, ): - if with_jax_gemm and isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - pytest.skip("`jax.nn.scaled_matmul()` does not support the Shardy partitioner.") + if dtype == jnp.float16 and 
quantization_recipe is not None and quantization_recipe.nvfp4(): + pytest.skip("NVFP4 GEMM + Float16 output is unsupported!") self._test_layernorm_mlp( mesh_config, activation_type, @@ -513,7 +576,7 @@ def test_layernorm_mlp_layer_fp8_shardy( input_shape, dtype, use_fp8=True, - fp8_recipe=fp8_recipe, + quantization_recipe=quantization_recipe, use_shardy=True, with_jax_gemm=with_jax_gemm, ) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index d9eaf314a..f1ae6c9e4 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -15,7 +15,7 @@ from distributed_test_base import generate_configs, generate_collectives_count from distributed_test_base import compare_ops from utils import make_causal_mask, make_self_mask -from transformer_engine.jax import fp8_autocast +from transformer_engine.jax import autocast from transformer_engine.jax.softmax import SoftmaxType, softmax DTYPES = [jnp.float16, jnp.bfloat16] @@ -102,9 +102,11 @@ def impl_test_softmax( collective_count_ref = self.generate_collectives_count_ref() devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) - with mesh, fp8_autocast(mesh_resource=mesh_resource): - x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) - mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec)) + with mesh, autocast(mesh_resource=mesh_resource): + x_named_sharding = NamedSharding(mesh, x_pspec) + mask_named_sharding = NamedSharding(mesh, mask_pspec) + x_ = jax.device_put(x, x_named_sharding) + mask_ = jax.device_put(mask, mask_named_sharding) with warnings.catch_warnings(record=True) as warns: try: @@ -116,8 +118,8 @@ def impl_test_softmax( grad_args=(0,), metric_fwd_dtype=dtype, metric_bwd_dtype=dtype, - in_shardings=(x_pspec, mask_pspec), - out_shardings=(None, (x_pspec,)), + in_shardings=(x_named_sharding, mask_named_sharding), + out_shardings=(None, x_named_sharding), ) except AssertionError as err: 
# Softmax should still produce the correct numerical result with diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 4d7718cd0..72797f556 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -24,8 +24,8 @@ from jax.sharding import Mesh, NamedSharding, PartitionSpec from jax.typing import ArrayLike, DTypeLike -from transformer_engine.jax import fp8_autocast from transformer_engine.jax.cpp_extensions.misc import is_hip_extension +from transformer_engine.jax import autocast from transformer_engine.jax.sharding import MeshResource from transformer_engine.jax.attention import ( AttnBiasType, @@ -35,6 +35,7 @@ reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, fused_attn, + run_length_fill, make_swa_mask, SequenceDescriptor, CPStrategy, @@ -175,15 +176,34 @@ def make_mask( jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape ) - # causal mask - if attn_mask_type.is_causal(): + if attn_mask_type.is_bottom_right(): + run_length_out_q = run_length_fill(segment_ids_q) + run_length_out_kv = run_length_fill(segment_ids_kv) + bottom_right_causal_mask = make_attention_mask( + run_length_out_q - segment_pos_q, + run_length_out_kv - segment_pos_kv, + jnp.less_equal, + ) + inv_mask = combine_masks(bottom_right_causal_mask, inv_mask) + elif attn_mask_type.is_causal(): inv_causal_mask = make_attention_mask( segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) inv_mask = combine_masks(inv_causal_mask, inv_mask) # sliding window mask - inv_swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, jnp.bool_) + inv_swa_mask = ( + make_swa_mask( + segment_pos_q, + segment_pos_kv, + window_size, + dtype=jnp.bool, + segment_ids_q=segment_ids_q, + segment_ids_kv=segment_ids_kv, + ) + if attn_mask_type.is_bottom_right() + else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool_) + ) inv_mask = combine_masks(inv_mask, inv_swa_mask) mask = 
jnp.logical_not(inv_mask) return mask @@ -342,6 +362,16 @@ def _check_configs(self): if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") + if self.attn_mask_type.is_bottom_right(): + if self.max_seqlen_q > self.max_seqlen_kv: + pytest.skip( + f"BRCM requires cross attn type pattern, i.e. max_seqlen_kv >= max_seqlen_q" + ) + if self.attn_bias_type is not AttnBiasType.NO_BIAS: + pytest.skip(f"cuDNN does not support pre or post scale bias for BRCM") + if self.dropout_prob != 0.0: + pytest.skip(f"cuDNN does not support non-zero dropouts for BRCM") + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") @@ -352,14 +382,15 @@ pytest.skip( "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) - + # TODO(KshitijLakhani): Set the upper limit for skipping this test when cuDNN adds support if ( - get_device_compute_capability(0) == 100 + get_device_compute_capability(0) >= 100 and self.dropout_prob == 0.1 and self.attn_bias_type is not AttnBiasType.NO_BIAS + and not is_hip_extension() ): pytest.skip( - "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats @@ -551,7 +582,11 @@ def generate_random_segment_ids( self.pad_kv = self.pad_q else: # Force kv_len >= q_len for swa, otherwise, cuDNN kernels don't support - min_segment_len = None if self.window_size is None else self.seqlens_q + min_segment_len = None + if ( + self.window_size is not None or self.attn_mask_type.is_bottom_right() + ): # SWA or BRCM requires kv_len >= q_len + min_segment_len = self.seqlens_q self.segment_ids_kv, 
self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, @@ -768,7 +803,7 @@ def test_forward(self): ], ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): primitive_out = customcall_fused_dpa_jit(*customcall_args) primitive_out = self.cp_inverse_reorder_fn(primitive_out) @@ -785,7 +820,7 @@ def test_forward(self): assert_allclose(primitive_valid, reference_valid, dtype=self.dtype) if self.coll_count_ref is not None: - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): target_hlo = ( customcall_fused_dpa_jit.lower(*customcall_args, **kwargs).compile().as_text() ) @@ -885,7 +920,7 @@ def grad_func(func, *args, cp_reverse_out=False, **kwargs): ) ) - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): primitive_out, primitive_dgrad = jitted_primitive(*customcall_args) reference_out, reference_dgrad = jitted_reference(*args) @@ -956,7 +991,7 @@ def check_dqkv(primitive, reference, pad, idx): ) if self.coll_count_ref is not None: - with self.mesh, fp8_autocast(mesh_resource=self.mesh_resource): + with self.mesh, autocast(mesh_resource=self.mesh_resource): target_hlo = jitted_primitive.lower(*customcall_args).compile().as_text() assert_equal_collectives(target_hlo, self.coll_count_ref) @@ -968,6 +1003,9 @@ def check_dqkv(primitive, reference, pad, idx): pytest.param(AttnMaskType.PADDING_MASK, id="PADDING"), pytest.param(AttnMaskType.CAUSAL_MASK, id="CAUSAL"), pytest.param(AttnMaskType.PADDING_CAUSAL_MASK, id="PADDING_CAUSAL"), + pytest.param( + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, id="PADDING_CAUSAL_BOTTOM_RIGHT" + ), ], ) @pytest.mark.parametrize( diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py deleted file mode 100644 index e4511e1fe..000000000 --- a/tests/jax/test_helper.py 
+++ /dev/null @@ -1,148 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import unittest - -import flax -import jax -import jax.numpy as jnp -import numpy as np - -from utils import assert_allclose -from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling, Float8CurrentScaling -from transformer_engine.common.recipe import Format as FP8Format -from transformer_engine.jax import fp8_autocast, get_delayed_scaling -from transformer_engine.jax.quantize import ( - get_quantize_config, - is_fp8_available, - ScalingMode, - update_collections, - TensorSource, -) -from transformer_engine.jax.sharding import MeshResource, global_mesh_resource - -is_fp8_supported, reason = is_fp8_available() -is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING) - - -class TestHelper(unittest.TestCase): - - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_update_collections(self): - original_val = 0.0 - updated_val = 10.0 - - original_state = { - "test1": original_val, - "test2": original_val, - } - updated_state = update_collections({"test1": updated_val}, original_state) - self.assertEqual(updated_state["test1"], updated_val) - self.assertEqual(updated_state["test2"], original_val) - - original_state = flax.core.frozen_dict.FrozenDict(original_state) - updated_state = update_collections({"test1": updated_val}, original_state) - self.assertEqual(updated_state["test1"], updated_val) - self.assertEqual(updated_state["test2"], original_val) - - -class TestFP8Functions(unittest.TestCase): - - def _check_default_state(self): - self.assertFalse(get_quantize_config().is_fp8_enabled()) - - def _compare_delay_scaling(self, ref, test): - self.assertTrue(ref.margin == test.margin) - self.assertTrue(ref.fp8_format == test.fp8_format) - self.assertTrue(ref.amax_history_len == test.amax_history_len) - self.assertTrue(ref.amax_compute_algo == 
test.amax_compute_algo) - - def _compare_current_scaling(self, test): - self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) - for tensor_source in TensorSource: - self.assertEqual( - get_quantize_config().get_scaling_mode(tensor_source), - ScalingMode.CURRENT_TENSOR_SCALING, - ) - - def _compare_mxfp8_scaling(self, test): - self.assertEqual(get_quantize_config().MARGIN, test.margin) - self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) - for tensor_source in TensorSource: - self.assertEqual( - get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING - ) - - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_fp8_autocast_delayed_scaling(self): - self._check_default_state() - - with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): - self._check_default_state() - - self._check_default_state() - - ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) - - self._check_default_state() - - ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) - - self._check_default_state() - - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_fp8_autocast_current_scaling(self): - self._check_default_state() - - with fp8_autocast( - enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource() - ): - self._check_default_state() - - self._check_default_state() - - cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): - 
self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_current_scaling(cs) - - self._check_default_state() - - cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_current_scaling(cs) - - self._check_default_state() - - @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) - def test_fp8_autocast_mxfp8_block_scaling(self): - self._check_default_state() - - with fp8_autocast( - enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource() - ): - self._check_default_state() - - self._check_default_state() - - bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_mxfp8_scaling(bs) - - self._check_default_state() - - bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(get_quantize_config().is_fp8_enabled()) - self._compare_mxfp8_scaling(bs) - - self._check_default_state() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 6f672ade7..b51d6b213 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -23,12 +23,13 @@ from transformer_engine.common import recipe from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.quantize import ( - get_quantize_config, + get_global_quantize_recipe, + get_quantize_config_with_recipe, ScalingMode, is_fp8_available, update_collections, TensorSource, - fp8_autocast, + autocast, ) from transformer_engine.jax.sharding import MeshResource @@ -358,7 +359,7 @@ def test_backward( ref_params, test_params = self._sync_params(ref_params, test_params) - if get_quantize_config().is_fp8_enabled(): + if 
get_quantize_config_with_recipe(get_global_quantize_recipe()).is_fp8_enabled(): for _ in range(4): _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( inputs, @@ -368,14 +369,24 @@ def test_backward( test_layer, ) if ( - get_quantize_config().get_scaling_mode(TensorSource.X) + get_quantize_config_with_recipe(get_global_quantize_recipe()).get_scaling_mode( + TensorSource.X + ) == ScalingMode.DELAYED_TENSOR_SCALING ): _, updated_quantize_meta = flax.core.pop( - updated_state[0], get_quantize_config().COLLECTION_NAME + updated_state[0], + get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).COLLECTION_NAME, ) test_others = update_collections( - {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others + { + get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).COLLECTION_NAME: updated_quantize_meta + }, + test_others, ) del updated_quantize_meta del updated_state @@ -507,14 +518,14 @@ def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" # Ensure FP8 disabled. # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=False, mesh_resource=MeshResource()): + with autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" # Ensure FP8 disabled. 
# Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=False, mesh_resource=MeshResource()): + with autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -522,7 +533,7 @@ def test_backward(self, data_shape, dtype, attrs): def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -530,7 +541,7 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" # Empty MeshResource is used as we are running on a single device - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with autocast(enabled=True, recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) diff --git a/tests/jax/test_recipe_characteristics.py b/tests/jax/test_recipe_characteristics.py new file mode 100644 index 000000000..5171a6c62 --- /dev/null +++ b/tests/jax/test_recipe_characteristics.py @@ -0,0 +1,443 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import unittest +from functools import partial +from abc import ABC, abstractmethod + +import flax +import jax +import jax.numpy as jnp +import numpy as np +from flax import linen as nn + +from utils import assert_allclose, pytest_parametrize_wrapper +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + MXFP8BlockScaling, + Float8CurrentScaling, + NVFP4BlockScaling, +) +from transformer_engine.common.recipe import Format as FP8Format +from transformer_engine.jax import autocast +from transformer_engine.jax.quantize import ( + get_global_quantize_recipe, + get_quantize_config_with_recipe, + get_supported_quantization_recipes, + is_scaling_mode_supported, + ScalingMode, + update_collections, + TensorSource, + QuantizeLayout, +) +from transformer_engine.jax.quantize.helper import _format2dtypes +from transformer_engine.jax.sharding import MeshResource, global_mesh_resource +from transformer_engine.jax.flax.module import TransformerEngineBase +from transformer_engine.jax import flax as te_flax +import transformer_engine.jax as te + +is_fp8_supported, reason = is_scaling_mode_supported(ScalingMode.DELAYED_TENSOR_SCALING) +is_mxfp8_supported, mxfp8_reason = is_scaling_mode_supported(ScalingMode.MXFP8_1D_SCALING) +is_nvfp4_supported, nvfp4_reason = is_scaling_mode_supported(ScalingMode.NVFP4_1D_SCALING) + +SUPPORTED_RECIPES = get_supported_quantization_recipes() + + +def quantizer_check_vjp(outer_quantizer_set, assertion_func, x): + """Check that the quantizers in the quantizer set are as expected and reconstructed correctly from flattened pytree representations across VJP boundaries.""" + + # Define a function with a custom VJP (vector-Jacobian product) + @partial(jax.custom_vjp, nondiff_argnums=(1,)) + def quantizer_check(inner_quantizer_set, assertion_func, x): + return quantizer_check_fwd(inner_quantizer_set, assertion_func, x)[0] + + def quantizer_check_fwd(inner_quantizer_set, assertion_func, x): + assertion_func(inner_quantizer_set.x, 
TensorSource.X) + assertion_func(inner_quantizer_set.kernel, TensorSource.KERNEL) + assertion_func(inner_quantizer_set.dgrad, TensorSource.DGRAD) + return x, (inner_quantizer_set,) + + def quantizer_check_bwd(assertion_func, ctx, g): + (inner_quantizer_set,) = ctx + return (inner_quantizer_set, g) + + quantizer_check.defvjp(quantizer_check_fwd, quantizer_check_bwd) + return quantizer_check(outer_quantizer_set, assertion_func, x) + + +class TestModule(TransformerEngineBase): + """A simple module to test quantizer creation and reconstruction across VJP boundaries.""" + + # Signature: (quantizer: Quantizer, tensor_source: TensorSource) -> None + assertion_func: callable + direct_recipe: Recipe + + @nn.compact + def __call__(self, x): + quantizer_set = self.generate_quantizer_set(fp8_recipe=self.direct_recipe) + return quantizer_check_vjp(quantizer_set, self.assertion_func, x) + + +class TestHelper(unittest.TestCase): + + @unittest.skipIf(not is_fp8_supported, reason=reason) + def test_update_collections(self): + original_val = 0.0 + updated_val = 10.0 + + original_state = { + "test1": original_val, + "test2": original_val, + } + updated_state = update_collections({"test1": updated_val}, original_state) + self.assertEqual(updated_state["test1"], updated_val) + self.assertEqual(updated_state["test2"], original_val) + + original_state = flax.core.frozen_dict.FrozenDict(original_state) + updated_state = update_collections({"test1": updated_val}, original_state) + self.assertEqual(updated_state["test1"], updated_val) + self.assertEqual(updated_state["test2"], original_val) + + +def assert_fp8_format(quantizer, tensor_source, fp8_format): + if fp8_format == FP8Format.HYBRID: + if tensor_source == TensorSource.DGRAD: + assert quantizer.q_dtype == jnp.float8_e5m2 + else: + assert quantizer.q_dtype == jnp.float8_e4m3fn + elif fp8_format == FP8Format.E4M3: + assert quantizer.q_dtype == jnp.float8_e4m3fn + else: + raise ValueError(f"Unsupported FP8 format: {fp8_format}") + + 
+class RecipeAssertionBase(ABC): + """Base class for defining recipe assertions.""" + + @abstractmethod + def assert_context(self, ref_recipe, quantize_config): + """Asserts that the quantize_config matches the expected properties from the reference recipe when the recipe is used with an autocast context. + + Args: + ref_recipe: The reference quantization recipe. + quantize_config: The quantization configuration to be checked. + """ + pass + + @abstractmethod + def assert_quantizers(self, ref_recipe, quantizer, tensor_source): + """Asserts that the quantizer matches the expected properties from the reference recipe. The quantizers are created in a small test Flax module TestModule and passed through a VJP boundary to ensure correct reconstruction. + + Args: + ref_recipe: The reference quantization recipe. + quantizer: The quantizer to be checked. + tensor_source: The source of the tensor (e.g., KERNEL, X, DGRAD). + """ + pass + + +class DelayedScalingRecipeAssertion(RecipeAssertionBase): + + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.MARGIN == ref_recipe.margin + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1] + assert quantize_config.AMAX_HISTORY_LEN == ref_recipe.amax_history_len + assert quantize_config.AMAX_COMPUTE_ALGO.value == ref_recipe.amax_compute_algo + for tensor_source in TensorSource: + assert ( + quantize_config.get_scaling_mode(tensor_source) + == ScalingMode.DELAYED_TENSOR_SCALING + ) + + def assert_quantizers(self, ref_recipe: DelayedScaling, quantizer, tensor_source): + assert quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING + assert quantizer.margin == ref_recipe.margin + assert quantizer.amax_compute_algo.value == ref_recipe.amax_compute_algo + assert quantizer.amax_history.shape == (ref_recipe.amax_history_len,) + assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format) + + +class 
CurrentScalingRecipeAssertion(RecipeAssertionBase): + + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1] + for tensor_source in TensorSource: + assert ( + quantize_config.get_scaling_mode(tensor_source) + == ScalingMode.CURRENT_TENSOR_SCALING + ) + + def assert_quantizers(self, ref_recipe: Float8CurrentScaling, quantizer, tensor_source): + assert quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format) + + +class MXFP8RecipeAssertion(RecipeAssertionBase): + + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp8_format)[1] + for tensor_source in TensorSource: + assert quantize_config.get_scaling_mode(tensor_source) == ScalingMode.MXFP8_1D_SCALING + + def assert_quantizers(self, ref_recipe: MXFP8BlockScaling, quantizer, tensor_source): + assert quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING + assert_fp8_format(quantizer, tensor_source, ref_recipe.fp8_format) + + +class NVFP4RecipeAssertion(RecipeAssertionBase): + + def assert_context(self, ref_recipe, quantize_config): + assert quantize_config.FWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[0] + assert quantize_config.BWD_DTYPE == _format2dtypes(ref_recipe.fp4_format)[1] + for tensor_source in TensorSource: + target_scaling_mode = ( + ScalingMode.NVFP4_2D_SCALING + if (not ref_recipe.disable_2d_quantization) and tensor_source == TensorSource.KERNEL + else ScalingMode.NVFP4_1D_SCALING + ) + assert quantize_config.get_scaling_mode(tensor_source) == target_scaling_mode + assert quantize_config.DISABLE_STOCHASTIC_ROUNDING == ref_recipe.disable_stochastic_rounding + assert quantize_config.DISABLE_RHT == ref_recipe.disable_rht + 
assert quantize_config.DISABLE_2D_QUANTIZATION == ref_recipe.disable_2d_quantization + + def assert_quantizers(self, ref_recipe: NVFP4BlockScaling, quantizer, tensor_source): + if tensor_source == TensorSource.KERNEL and not ref_recipe.disable_2d_quantization: + assert quantizer.scaling_mode == ScalingMode.NVFP4_2D_SCALING + else: + assert quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + + if ref_recipe.disable_stochastic_rounding or tensor_source != TensorSource.DGRAD: + assert quantizer.stochastic_rounding_rng_state is None + else: + assert quantizer.stochastic_rounding_rng_state is not None + + expected_rht = ( + quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + and quantizer.q_layout in {QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE} + and not ref_recipe.disable_rht + ) + assert quantizer.use_rht == expected_rht + + +class TestFP8Functions(unittest.TestCase): + + def _check_default_state(self): + self.assertEqual(get_global_quantize_recipe(), None) + + def _test_recipe(self, quantization_recipe: Recipe, cls: RecipeAssertionBase): + """Tests a quantization recipe by verifying its behavior in both autocast and direct application contexts.""" + assert_context_func = cls().assert_context + assert_quantizer_func = partial(cls().assert_quantizers, quantization_recipe) + self._test_recipe_autocast(quantization_recipe, assert_context_func, assert_quantizer_func) + self._test_recipe_direct(quantization_recipe, assert_quantizer_func) + + def _test_recipe_autocast( + self, quantization_recipe, assert_context_func, assert_quantizer_func + ): + """Tests a quantization recipe within an autocast context by verifying the quantize config and quantizers in a test module.""" + self._check_default_state() + with autocast(enabled=False, recipe=quantization_recipe, mesh_resource=MeshResource()): + self._check_default_state() + with autocast(enabled=True, recipe=quantization_recipe, mesh_resource=MeshResource()): + quantize_config = 
self._get_global_quantize_config() + assert_context_func(quantization_recipe, quantize_config) + self._test_quantizer_in_model(assert_quantizer_func) + self._check_default_state() + + def _test_recipe_direct(self, quantization_recipe, assert_quantizer_func): + """Tests a quantization recipe by directly passing it to a test module and verifying the quantizers.""" + self._check_default_state() + self._test_quantizer_in_model(assert_quantizer_func, direct_recipe=quantization_recipe) + self._check_default_state() + + def _test_quantizer_in_model(self, assert_quantizer_func, direct_recipe=None): + """Tests that the quantizers created in a test module match the expected properties by passing them through a VJP boundary. + + Args: + assert_quantizer_func: A function that asserts the properties of the quantizers. The function signature is (quantizer: Quantizer, tensor_source: TensorSource) -> None. + direct_recipe: An optional quantization recipe to be passed directly to the test module. This is an alternative API to using autocast contexts. 
+ """ + x = jnp.ones((), dtype=jnp.float32) + test_module = TestModule(assertion_func=assert_quantizer_func, direct_recipe=direct_recipe) + param_key, sr_key = jax.random.split(jax.random.PRNGKey(0)) + rngs = {"params": param_key, "sr_rng": sr_key} + variables = test_module.init(rngs, x) + + jax.jit(jax.value_and_grad(test_module.apply), static_argnums=(2,))(variables, x, rngs=rngs) + + def _get_global_quantize_config(self): + quantization_recipe = get_global_quantize_recipe() + assert quantization_recipe is not None, "No global quantization recipe set" + quantize_config = get_quantize_config_with_recipe(quantization_recipe) + assert ( + quantize_config.is_fp8_enabled() + ), "Quantization not enabled in global quantize config" + return quantize_config + + @unittest.skipIf(not is_fp8_supported, reason=reason) + def test_autocast_delayed_scaling(self): + self._test_recipe( + quantization_recipe=DelayedScaling(), + cls=DelayedScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=DelayedScaling( + margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1 + ), + cls=DelayedScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=DelayedScaling( + margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1 + ), + cls=DelayedScalingRecipeAssertion, + ) + + @unittest.skipIf(not is_fp8_supported, reason=reason) + def test_autocast_current_scaling(self): + self._test_recipe( + quantization_recipe=Float8CurrentScaling(), + cls=CurrentScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3), + cls=CurrentScalingRecipeAssertion, + ) + self._test_recipe( + quantization_recipe=Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID), + cls=CurrentScalingRecipeAssertion, + ) + + @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) + def test_autocast_mxfp8_block_scaling(self): + self._test_recipe( + quantization_recipe=MXFP8BlockScaling(), + 
cls=MXFP8RecipeAssertion, + ) + + @unittest.skipIf(not is_nvfp4_supported, reason=nvfp4_reason) + def test_autocast_nvfp4_block_scaling(self): + self._test_recipe( + quantization_recipe=NVFP4BlockScaling(), + cls=NVFP4RecipeAssertion, + ) + self._test_recipe( + quantization_recipe=NVFP4BlockScaling( + disable_stochastic_rounding=True, + disable_rht=True, + disable_2d_quantization=True, + ), + cls=NVFP4RecipeAssertion, + ) + + +class TestJaxprAndHlo: + """Tests to verify Jaxpr and/or HLO of compiled modules apply expected recipe functionality and optimizations.""" + + def _generate_jaxpr_for_layernorm_mlp_fwd_bwd(self, quantization_recipe, ln_mlp_kwargs=None): + """Generates the jaxpr for a forward and backward pass of LayerNormMLP under the given quantization recipe.""" + ln_mlp_kwargs = ln_mlp_kwargs or {} + with te.autocast(enabled=True, recipe=quantization_recipe, mesh_resource=te.MeshResource()): + model = te_flax.LayerNormMLP( + layernorm_type="rmsnorm", + return_layernorm_output=False, + intermediate_dropout_rate=0.0, + dtype=jnp.bfloat16, + **ln_mlp_kwargs, + ) + + var_collect = model.init( + jax.random.PRNGKey(0), + jnp.ones((128, 128), dtype=jnp.bfloat16), + ) + + def loss_fn(x, rngs): + return jnp.mean(model.apply(var_collect, x, rngs=rngs)[0]) + + x = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.bfloat16) + rngs = {"sr_rng": jax.random.PRNGKey(1), "dropout": jax.random.PRNGKey(2)} + return jax.make_jaxpr(jax.value_and_grad(loss_fn))(x, rngs=rngs) + + @pytest_parametrize_wrapper( + "quantization_recipe", + [ + quantization_recipe + for quantization_recipe in SUPPORTED_RECIPES + if isinstance(quantization_recipe, NVFP4BlockScaling) + ], + ) + def test_layernorm_mlp_reuses_amax_nvfp4(self, quantization_recipe): + """Tests that layernorm_mlp reuses the amax computed in layernorm and the activation and does not recompute it during quantizaton.""" + + jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe) + + 
rht_amax_eqns = [ + eqn for eqn in jaxpr.jaxpr.eqns if eqn.primitive.name == "te_rht_amax_ffi_wrapper" + ] + + assert len(rht_amax_eqns) == 4, f"Expected 4 rht_amax_eqns, got {len(rht_amax_eqns)}" + + def assert_param(index, tensor_name, expected_value: bool): + if expected_value: + assert rht_amax_eqns[index].params["produce_regular_amax"] == True, ( + f"Expected produce_regular_amax for {tensor_name} to be True, indicating no" + " reuse of amax as this tensor does not have a previous operation to fuse" + " with" + ) + else: + assert rht_amax_eqns[index].params["produce_regular_amax"] == False, ( + f"Expected produce_regular_amax for {tensor_name} to be False, indicating" + " reuse of amax" + ) + + assert_param(0, "fwd ln+q", False) + assert_param(1, "fwd act+q", False) + # No previous op before incoming dgrad in the backward so amax is not reused + assert_param(2, "bwd dgrad", True) + assert_param(3, "bwd dact+q", False) + + @pytest_parametrize_wrapper("quantization_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper( + "quantization_checkpoint_name", + [None, "quantization", "some_arbitrary_user_checkpoint_name"], + ) + def test_recipe_supports_quantization_checkpointing( + self, quantization_recipe, quantization_checkpoint_name + ): + """Tests that all supported quantization recipes correctly use checkpoint_name.""" + + kwargs = { + "quantization_checkpoint_name": quantization_checkpoint_name, + } + jaxpr = self._generate_jaxpr_for_layernorm_mlp_fwd_bwd(quantization_recipe, kwargs) + + checkpoint_name_eqns = [ + eqn + for eqn in jaxpr.jaxpr.eqns + if eqn.primitive.name == "name" and eqn.params["name"] == quantization_checkpoint_name + ] + + if quantization_checkpoint_name is None: + assert len(checkpoint_name_eqns) == 0, ( + "Expected 0 checkpoint_name eqns when quantization_checkpoint_name is None, got" + f" {len(checkpoint_name_eqns)}" + ) + return + + # 12 checkpointed values: + # - Fwd pass: + # - Input RMSNorm+Q -> 3 possible output tensors that will 
be used in the backward + # - Kernel Q -> 3 possible output tensors that will be used in the backward + # - Input Activation+Q -> 3 possible output tensors that will be used in the backward + # - Kernel Q -> 3 possible output tensors that will be used in the backward + expected_checkpoint_eqn_count = 12 + + assert len(checkpoint_name_eqns) == expected_checkpoint_eqn_count, ( + f"Expected {expected_checkpoint_eqn_count} checkpoint_name eqns when" + f" quantization_checkpoint_name is set, got {len(checkpoint_name_eqns)}" + ) diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 56d5df8e3..60b01348f 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -367,9 +367,9 @@ class MlpBlock(nn.Module): transpose_batch_sequence: bool intermediate_dim: int = 2048 - activations: Sequence[Union[str, Callable]] = ("relu",) + activations: Sequence[Union[str, Callable]] = ("gelu",) kernel_init: Initializer = None - intermediate_dropout_rate: float = 0.1 + intermediate_dropout_rate: float = 0.0 intermediate_dropout_dims: Sequence[int] = () use_bias: bool = False dtype: Any = jnp.float32 @@ -1038,14 +1038,14 @@ class EncoderLayer(nn.Module): hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 + intermediate_dropout: float = 0.0 intermediate_dropout_dims: Sequence[int] = () transpose_batch_sequence: bool = True float32_attention_logits: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 - mlp_activations: Sequence[str] = ("relu",) + mlp_activations: Sequence[str] = ("gelu",) use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False @@ -1202,14 +1202,14 @@ class DecoderLayer(nn.Module): hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 + intermediate_dropout: float = 0.0 intermediate_dropout_dims: Sequence[int] = () transpose_batch_sequence: bool = True float32_attention_logits: bool = False scale_attn_logits: bool = False scaled_query_init: bool = True mlp_dim: int = 2048 - mlp_activations: Sequence[str] = ("relu",) + mlp_activations: Sequence[str] = ("gelu",) use_bias: bool = False dtype: Any = jnp.float32 apply_residual_connection_post_layernorm: bool = False @@ -1547,6 +1547,12 @@ def dtype_tols( rtol = eps_relaxed if atol is None: atol = max(ulp, eps_relaxed) + + # Manually set tols for nvfp4 + if dtype == jnp.float4_e2m1fn: + atol = 0.05 + rtol = 0.1 + return {"rtol": rtol, "atol": atol} diff --git a/tests/pytorch/attention/run_attention_with_cp.py 
b/tests/pytorch/attention/run_attention_with_cp.py index 10bb066a4..9e59f4f6a 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -10,99 +10,34 @@ from contextlib import nullcontext import torch import torch.distributed as dist +import warnings + from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.attention import DotProductAttention + from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_cu_seqlens_on_cp_rank, ) +from transformer_engine.pytorch.attention.dot_product_attention.utils import combine_and_quantize import transformer_engine_torch as tex from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn -from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer -from transformer_engine.common.recipe import DelayedScaling -import warnings +from transformer_engine.pytorch import ( + autocast, + DotProductAttention, + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling +from utils import ModelConfig, compare_and_assert dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} -def run_dpa_with_cp( - dtype="bf16", - model=None, - qkv_format="bshd", - kernel_backend="FlashAttention", - cp_comm_type="p2p", - fp8_mha=False, +def generate_input_shapes( + qkv_format: str, + config: ModelConfig, + world_size: int, + kernel_backend: str, ): - """Test DotProductAttention 
module with context parallelism""" - - # args are passed as strings - fp8_mha = fp8_mha == "True" - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - if kernel_backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - config = model_configs_flash_attn[model] - if kernel_backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - config = model_configs_fused_attn[model] - - assert config.attn_mask_type in [ - "causal", - "no_mask", - ], f"{config.attn_mask_type} is an unsupported attention mask type!" - if qkv_format == "thd": - if "causal" in config.attn_mask_type: - config.attn_mask_type = "padding_causal" - else: - config.attn_mask_type = "padding" - - rank = int(os.getenv("RANK", "0")) - world_size = int(os.getenv("WORLD_SIZE", "1")) - - if dist.is_initialized(): - world_size = dist.get_world_size() - rank = dist.get_rank() - else: - device_count = torch.cuda.device_count() - device = rank % device_count - torch.cuda.set_device(device) - - print(f"[INFO] world_size:{world_size}, rank:{rank}") - - dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) - - # create flash attn comm group for CP - cp_comm_ranks = range(world_size) - assert rank in cp_comm_ranks - cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") - if cp_comm_type == "a2a+p2p": - assert ( - world_size % 2 == 0 - ), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!" 
- cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] - cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] - cp_comm_sub_groups = [] - for sub_ranks in cp_comm_sub_ranks: - sub_group = dist.new_group(sub_ranks, backend="nccl") - if rank in sub_ranks: - cp_comm_sub_groups.append(sub_group) - - if dtype == "fp8": - fp8_recipe = DelayedScaling(fp8_dpa=True, fp8_mha=fp8_mha) - - # instantiate core attn module - core_attn = DotProductAttention( - config.num_heads, - (config.head_dim_qk, config.head_dim_v), - num_gqa_groups=config.num_gqa_groups, - attention_dropout=config.dropout_p, - qkv_format=qkv_format, - attn_mask_type=config.attn_mask_type, - window_size=config.window_size, - ) - core_attn = core_attn.cuda() - - # create flash attn inputs if qkv_format == "bshd": q_input_shape = ( config.batch_size, @@ -197,35 +132,194 @@ def run_dpa_with_cp( cu_seqlens_kv = cu_seqlens_q cu_seqlens_kv_padded = cu_seqlens_q_padded else: - assert False, f"{qkv_format} is an unsupported qkv_format!" - - q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda() - k = torch.randn(k_input_shape, dtype=dtypes[dtype]).cuda() - v = torch.randn(v_input_shape, dtype=dtypes[dtype]).cuda() - dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda() - dout_quantizer = Float8Quantizer( - fp8_dtype=tex.DType.kFloat8E5M2, - scale=torch.tensor([1], dtype=torch.float32).cuda(), - amax=torch.tensor([0], dtype=torch.float32).cuda(), + assert False, f"{qkv_format=} is not supported!" 
+ + return ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ) - # create flash attention bias + +def get_tols(config, dtype): + if dtype == "bf16": + if config.num_heads == config.num_gqa_groups: + atol = 2.5e-2 + rtol = 2.5e-2 + else: + atol = 3.5e-2 + rtol = 3.5e-2 + rmse_tol = 0.01 + elif dtype == "fp16": + atol = 5e-3 + rtol = 5e-3 + rmse_tol = 0.01 + elif dtype == "fp8": + atol = 5e-1 + rtol = 5e-1 + rmse_tol = 0.15 + else: + assert False, f"{dtype=} is not supported!" + + return atol, rtol, rmse_tol + + +def run_dpa_with_cp( + dtype="bf16", + model=None, + qkv_format="bshd", + kernel_backend="FlashAttention", + cp_comm_type="p2p", + fp8_bwd="True", + fp8_dpa="False", + fp8_mha="False", + scaling_mode="delayed", + f16_O="False", + log_level=logging.WARNING, +): + """Test DotProductAttention module with context parallelism""" + logging.root.setLevel(log_level) + + # set up environment variables and config + fp8_bwd = fp8_bwd == "True" and dtype == "fp8" + os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_bwd else "0" + fp8_dpa = fp8_dpa == "True" and dtype == "fp8" + fp8_mha = fp8_mha == "True" and dtype == "fp8" + f16_O = dtype == "fp8" and scaling_mode == "current" and f16_O == "True" + os.environ["NVTE_DPA_FP8CS_O_in_F16"] = "1" if f16_O else "0" + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + if kernel_backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + config = model_configs_flash_attn[model] + if kernel_backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + config = model_configs_fused_attn[model] + assert config.attn_mask_type in [ + "causal", + "no_mask", + ], f"{config.attn_mask_type=} is not supported!" 
+ if qkv_format == "thd": + if "causal" in config.attn_mask_type: + config.attn_mask_type = "padding_causal" + else: + config.attn_mask_type = "padding" + + # set up distributed group + rank = int(os.getenv("RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + if dist.is_initialized(): + world_size = dist.get_world_size() + rank = dist.get_rank() + else: + device_count = torch.cuda.device_count() + device = rank % device_count + torch.cuda.set_device(device) + logging.info(f"[Rank {rank}] Setup: world_size {world_size}") + dist.init_process_group(backend="nccl", world_size=world_size, rank=rank) + + # set up communication group for CP + cp_comm_ranks = range(world_size) + assert rank in cp_comm_ranks + cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl") + if cp_comm_type == "a2a+p2p": + assert world_size % 2 == 0, ( + "{cp_comm_type=} requires world_size % 2 = 0 as it assumes the a2a level has cp_size" + " = 2." + ) + cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)] + cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)] + cp_comm_sub_groups = [] + for sub_ranks in cp_comm_sub_ranks: + sub_group = dist.new_group(sub_ranks, backend="nccl") + if rank in sub_ranks: + cp_comm_sub_groups.append(sub_group) + + if dtype == "fp8": + if scaling_mode == "delayed": + fp8_recipe = DelayedScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + if scaling_mode == "current": + fp8_recipe = Float8CurrentScaling(fp8_dpa=fp8_dpa, fp8_mha=fp8_mha) + + # instantiate attention module + core_attn = DotProductAttention( + config.num_heads, + (config.head_dim_qk, config.head_dim_v), + num_gqa_groups=config.num_gqa_groups, + attention_dropout=config.dropout_p, + qkv_format=qkv_format, + attn_mask_type=config.attn_mask_type, + window_size=config.window_size, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, + ).cuda() + if config.softmax_type != "vanilla": + core_attn.softmax_offset.requires_grad = True + + # 
generate attention inputs + ( + q_input_shape, + k_input_shape, + v_input_shape, + attn_output_shape, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() + dout_orig = torch.clamp( + torch.randn(attn_output_shape, dtype=dtypes[dtype]), min=-1, max=1 + ).cuda() + if scaling_mode == "delayed": + qkv_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + dout_quantizer = Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + scale=torch.tensor([1], dtype=torch.float32).cuda(), + amax=torch.tensor([0], dtype=torch.float32).cuda(), + ) + if scaling_mode == "current": + qkv_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + dout_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E5M2, + device="cuda", + ) + qkv_layout = "_".join([qkv_format] * 3) + q, k, v, dout = [x.clone().detach() for x in [q_orig, k_orig, v_orig, dout_orig]] + if fp8_mha: + q, k, v = combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer) + for x in [q, k, v]: + x.requires_grad = True + if config.attn_bias_type not in ["no_bias", "alibi"]: attn_bias_shape = (1, 1, config.max_seqlen_q, config.max_seqlen_kv) bias = torch.randn(*attn_bias_shape, dtype=dtypes[dtype]).cuda() else: bias = None - # run core_attn without CP - for x in [q, k, v]: - x.requires_grad = True - + ############ run without CP ############ + logging.info(f"[Rank {rank}] Run without context parallelism") if dtype == "fp8": - fp8_context = fp8_autocast(enabled=True, 
fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() - + max_logit = None with fp8_context: + # q, k, v, out in FP8; dout in F16 out = core_attn( q, k, @@ -238,16 +332,27 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), + fp8_output=fp8_mha, ) - if fp8_mha: + if config.return_max_logit: + out, max_logit = out + if fp8_bwd and fp8_mha: dout_fp8 = dout_quantizer(dout) out.backward(dout_fp8) else: out.backward(dout) + dq, dk, dv = q.grad, k.grad, v.grad + d_softmax_offset = None + if config.softmax_type != "vanilla": + d_softmax_offset = core_attn.softmax_offset.grad + + ############ run with CP ############ + logging.info(f"[Rank {rank}] Run with context parallelism") - # run core_attn wit CP + # set up inputs q_, k_, v_, dout_, *rest = [ - x.clone().detach() for x in [q, k, v, dout] + ([] if bias is None else [bias]) + x.clone().detach() + for x in [q_orig, k_orig, v_orig, dout_orig] + ([] if bias is None else [bias]) ] bias_ = rest[0] if len(rest) else None if qkv_format == "bshd" or qkv_format == "sbhd": @@ -277,6 +382,14 @@ def run_dpa_with_cp( k_, v_ = [x.index_select(0, seq_idx_kv) for x in [k_, v_]] else: assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ q_, k_, v_, dout_ = [x.contiguous() for x in [q_, k_, v_, dout_]] + if scaling_mode == "delayed": + qkv_quantizer.scale.fill_(1.0) + qkv_quantizer.amax.fill_(0.0) + dout_quantizer.scale.fill_(1.0) + dout_quantizer.amax.fill_(0.0) + if fp8_mha: + q_, k_, v_ = combine_and_quantize(qkv_layout, q_, k_, v_, qkv_quantizer) q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]] if bias_ is not None: bias_ = bias_.view( @@ -284,20 +397,26 @@ def run_dpa_with_cp( ) bias_ = bias_.index_select(2, seq_idx) bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1]) + # set up environment core_attn.set_context_parallel_group( cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type, ) - + if config.softmax_type != "vanilla": + core_attn.softmax_offset.grad.zero_() if dtype == "fp8": - core_attn.reset_fp8_meta_tensors() - fp8_context = fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=cp_comm_group) + core_attn.fp8_initialized = False + core_attn.fp8_meta_tensors_initialized = False + fp8_context = autocast(enabled=True, recipe=fp8_recipe, amax_reduction_group=cp_comm_group) else: fp8_context = nullcontext() + # run attention + max_logit_ = None with fp8_context: + # q, k, v, out in FP8; dout in F16 out_ = core_attn( q_, k_, @@ -310,24 +429,35 @@ def run_dpa_with_cp( cu_seqlens_kv_padded=( None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1] ), + fp8_output=fp8_mha, ) - if fp8_mha: + if config.return_max_logit: + out_, max_logit_ = out_ + if fp8_bwd and fp8_mha: dout_fp8_ = dout_quantizer(dout_) out_.backward(dout_fp8_) else: out_.backward(dout_) + dq_, dk_, dv_ = q_.grad, k_.grad, v_.grad + d_softmax_offset_ = None + if config.softmax_type != "vanilla": + d_softmax_offset_ = core_attn.softmax_offset.grad.clone() + # get outputs + tensors = [out, dq, dk, dv, out_, dq_, dk_, dv_] if fp8_mha: - assert isinstance(out, Float8Tensor) - assert isinstance(out_, Float8Tensor) - out = out.dequantize() - 
out_ = out_.dequantize() - - for x in [out_, q_.grad, k_.grad, v_.grad]: - assert torch.all(~torch.isnan(x)) - assert torch.all(~torch.isinf(x)) - - # compare results with and without CP + tensors_to_deq = [out, out_] if not fp8_bwd else tensors + for i, tensor in enumerate(tensors_to_deq): + tensors_to_deq[i] = tensor.dequantize() + if not fp8_bwd: + tensors[0], tensors[4] = tensors_to_deq + i = 0 + for tensor in tensors[4:]: + assert torch.all(~torch.isnan(tensor)) + assert torch.all(~torch.isinf(tensor)) + out, dq, dk, dv, out_, dq_, dk_, dv_ = tensors + + ############ compare results between CP and no-CP ############ if qkv_format == "bshd" or qkv_format == "sbhd": dq, dk, dv, out = [ x.view( @@ -336,17 +466,17 @@ def run_dpa_with_cp( x.shape[seq_dim] // (2 * world_size), *x.shape[(seq_dim + 1) :], ) - for x in [q.grad, k.grad, v.grad, out] + for x in [dq, dk, dv, out] ] dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]] dq_, dk_, dv_, out_ = [ x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim] // 2, *x.shape[(seq_dim + 1) :]) - for x in [q_.grad, k_.grad, v_.grad, out_] + for x in [dq_, dk_, dv_, out_] ] elif qkv_format == "thd": - dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]] - dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]] - dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_] + dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] + dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] + dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True @@ -389,56 +519,70 @@ def run_dpa_with_cp( ).item() == 0 ) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" 
- - if dtype == "bf16": - if config.num_heads == config.num_gqa_groups: - tols = dict(atol=2.5e-2, rtol=2.5e-2) - else: - tols = dict(atol=3.5e-2, rtol=3.5e-2) - elif dtype == "fp16": - tols = dict(atol=5e-3, rtol=5e-3) - elif dtype == "fp8": - tols = dict(atol=5e-1, rtol=5e-1) - rmse_tol = 0.1 - else: - assert False, f"{dtype} is an unsupported dtype!" - def _rmse(a, b): - return torch.sqrt((a - b).square().mean()).item() - - def _error(a, b): - if dtype != "fp8": - torch.testing.assert_close(a, b, **tols) - else: - try: - torch.testing.assert_close(a, b, **tols) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert ( - rmse < rmse_tol * rmse_range - ), "RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - - if qkv_format == "bshd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[:, 0], b[:, 0]) - _error(a[:, 1], b[:, 1]) - elif qkv_format == "sbhd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a[0], b[0]) - _error(a[1], b[1]) - elif qkv_format == "thd": - for a, b in zip([out_, dq_, dk_, dv_], [out, dq, dk, dv]): - _error(a, b) - else: - assert False, f"{qkv_format} is an unsupported qkv_format!" 
+ atol, rtol, rmse_tol = get_tols(config, dtype) + tensors_cp = [out_, dq_, dk_, dv_, d_softmax_offset_, max_logit_] + tensors_no_cp = [out, dq, dk, dv, d_softmax_offset, max_logit] + names = ["out", "dq", "dk", "dv", "d_softmax_offset", "max_logit"] + names_cp = [x + "_cp" for x in names] + names_no_cp = [x + "_no_cp" for x in names] + is_fp8 = dtype == "fp8" + for i, t in enumerate(tensors_no_cp): + if t is not None: + if "softmax_offset" not in names[i] and "max_logit" not in names[i]: + if qkv_format == "bshd": + compare_and_assert( + t[:, 0], + tensors_cp[i][:, 0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[:, 1], + tensors_cp[i][:, 1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "sbhd": + compare_and_assert( + t[0], + tensors_cp[i][0], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + compare_and_assert( + t[1], + tensors_cp[i][1], + names_no_cp[i], + names_cp[i], + atol, + rtol, + rmse_tol, + is_fp8, + ) + elif qkv_format == "thd": + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + else: + compare_and_assert( + t, tensors_cp[i], names_no_cp[i], names_cp[i], atol, rtol, rmse_tol, is_fp8 + ) + logging.info(f"[Rank {rank}] CP vs no-CP: {names[i]} matches") + # destroy distribution group dist.destroy_process_group() diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index a5128653e..c0cf64803 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -4,7 +4,6 @@ # # See LICENSE for license information. 
import logging -import math import os import sys import pathlib @@ -15,13 +14,22 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, get_fp8_te_dtype from transformer_engine.common import recipe -from transformer_engine.pytorch import TransformerLayer, fp8_autocast, fp8_model_init -from transformer_engine.pytorch.attention.dot_product_attention import ( +from transformer_engine.pytorch import ( + TransformerLayer, + autocast, + quantized_model_init, DotProductAttention, + MultiheadAttention, + get_device_compute_capability, + Quantizer, + is_fp8_available, + is_bf16_available, +) +from transformer_engine.pytorch.attention.dot_product_attention import ( _attention_backends, ) -from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, check_set_window_size, @@ -34,17 +42,14 @@ fused_attn_fwd, ) from transformer_engine.pytorch.distributed import CudaRNGStatesTracker -import transformer_engine.pytorch.fp8 as fp8 from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.utils import ( - get_device_compute_capability, init_method_normal, scaled_init_method_normal, - is_bf16_compatible, ) from transformer_engine.pytorch.utils import get_cudnn_version import transformer_engine_torch as tex -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( Quantizer, prepare_for_saving, restore_from_saved, @@ -54,23 +59,33 @@ sys.path.append(str(_current_file.parent.parent)) from utils import ( reset_rng_states, + compare_and_assert, ModelConfig, dtype_tols, get_available_attention_backends, ) -# Only run FP8 tests on H100 -fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() +# Check if hardware supports FP8 attention. 
+fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +fp8_attn_available, reason_for_no_fp8_attn = fp8_available, reason_for_no_fp8 +device_compute_capability = get_device_compute_capability() +if fp8_available and (device_compute_capability < (9, 0) or device_compute_capability >= (12, 0)): + fp8_attn_available = False + reason_for_no_fp8_attn = ( + "FP8 attention is not supported for compute capability =" + f" sm{device_compute_capability[0] * 10 + device_compute_capability[1]}" + ) +# Reset RNG seed and states seed = 1234 -# Reset RNG states reset_rng_states() +# Reset FP8 global state manager @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield - fp8.FP8GlobalStateManager.reset() + FP8GlobalStateManager.reset() if IS_HIP_EXTENSION: @@ -82,9 +97,14 @@ def reset_attn_backend(): "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) yield +# Define F16 data types to test +param_types = [torch.float16] +if is_bf16_available(): + param_types.append(torch.bfloat16) +param_types_lean = [torch.bfloat16] model_configs_base = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "base_1_0": ModelConfig(8, 128, 16, 64), "base_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), "base_2_0": ModelConfig(2, 2048, 24, 128), @@ -109,10 +129,8 @@ def reset_attn_backend(): # backend is capable of supporting it. @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_gqa_mla_thd(): - """ - Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration - post-processing for BWD FA with native padding support. 
- """ + """Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration + post-processing for BWD FA with native padding support.""" # b, sq, h, dqk config = ModelConfig(8, 128, 16, 128, num_gqa_groups= 4, head_dim_v=64, attn_mask_type="padding") qkv_layout = "thd_thd_thd" @@ -129,11 +147,10 @@ def test_gqa_mla_thd(): test_dot_product_attention(dtype, {"layout_1": config}, "layout_1", False, False, qkv_layout, False, True, False) + @pytest.mark.skipif(not IS_HIP_EXTENSION, reason="ROCm TE specific pytests.") def test_dot_product_mem_calc(): - """ - Non-regression test for memory workspace calculation integer overflow issue. - """ + """Non-regression test for memory workspace calculation integer overflow issue.""" ckpt_attn = False pad_between_seqs = False if not is_bf16_compatible(): @@ -203,13 +220,18 @@ def test_dot_product_attention( if config.window_size == (-1, -1) and swa: config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) + qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if qkv_format == "thd" and "padding" not in config.attn_mask_type: + config.attn_mask_type = ( + "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" + ) + # Get backends is_training = True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -220,7 +242,6 @@ def test_dot_product_attention( config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=pad_between_seqs, is_training=is_training, ) @@ -246,7 +267,7 @@ def test_dot_product_attention( # UnfusedDotProductAttention backend if unfused_attn_supported: - unfused_attn_fwd, unfused_attn_bwd = _run_dot_product_attention( + unfused_attn_fwd, unfused_max_logit, unfused_attn_bwd = 
_run_dot_product_attention( dtype, config, "UnfusedDotProductAttention", @@ -260,7 +281,7 @@ def test_dot_product_attention( # FusedAttention backend if fused_attn_supported: if len(fused_attn_backends) == 1: - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -275,7 +296,7 @@ def test_dot_product_attention( os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" os.environ["NVTE_FUSED_ATTN_CK"] = "0" os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "1" - fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention( + fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -291,7 +312,7 @@ def test_dot_product_attention( os.environ["NVTE_FUSED_ATTN_AOTRITON"] = "0" os.environ["NVTE_CK_USES_FWD_V3"] = "1" os.environ["NVTE_CK_USES_BWD_V3"] = "1" - fused_attn_fwd_1, fused_attn_bwd_1 = _run_dot_product_attention( + fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention( dtype, config, "FusedAttention", @@ -320,7 +341,7 @@ def test_dot_product_attention( # FlashAttention backend if flash_attn_supported: - flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention( + flash_attn_fwd, _, flash_attn_bwd = _run_dot_product_attention( dtype, config, "FlashAttention", @@ -332,6 +353,7 @@ def test_dot_product_attention( share_cu_seqlens_ref, ) + # Compare results logging.info(f"[test_dot_product_attention]: is_training = {is_training}") if unfused_attn_supported and flash_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs flash attn") @@ -341,6 +363,8 @@ def test_dot_product_attention( if unfused_attn_supported and fused_attn_supported: logging.info("[test_dot_product_attention]: unfused attn vs fused attn") torch.testing.assert_close(fused_attn_fwd, unfused_attn_fwd, **tols) + if config.return_max_logit: + torch.testing.assert_close(fused_max_logit, unfused_max_logit, **tols) for i, _ in 
enumerate(unfused_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], unfused_attn_bwd[i], **tols) if fused_attn_supported and flash_attn_supported: @@ -369,23 +393,129 @@ def test_dpa_checkpoint(dtype, model_configs, model): test_dot_product_attention(dtype, model_configs, model, True, True, None, False, False, False) +model_configs_max_logit = { + # test: ModelConfig(b, sq, hq, dqk) + "max_logit_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "max_logit_2": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), + "max_logit_3": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), + "max_logit_4": ModelConfig( + 8, 128, 16, 192, max_seqlen_kv=2048, attn_bias_type="post_scale_bias" + ), + "max_logit_5": ModelConfig( + 8, 128, 16, 512, max_seqlen_kv=2048, attn_mask_type="causal", window_size=(20, 0) + ), + "max_logit_6": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("model_configs", [model_configs_max_logit]) +@pytest.mark.parametrize("model", model_configs_max_logit.keys()) +@pytest.mark.parametrize("qkv_layout", ["sbhd_sbhd_sbhd", "thd_thd_thd"]) +def test_dpa_max_logit(dtype, model_configs, model, qkv_layout): + """Test DotProductAttention module with checkpointing""" + config = model_configs[model] + config.return_max_logit = True + test_dot_product_attention(dtype, model_configs, model, False, True, qkv_layout, False, False) + + +model_configs_softmax = { + # test: ModelConfig(b, sq, hq, dqk) + "softmax_1_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8), + "softmax_1_1": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="off-by-one"), + "softmax_1_2": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, softmax_type="learnable"), + "softmax_2_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal"), + "softmax_2_1": 
ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), + "softmax_2_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), + "softmax_3_0": ModelConfig(2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding"), + "softmax_3_1": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="off-by-one" + ), + "softmax_3_2": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, attn_mask_type="padding", softmax_type="learnable" + ), + "softmax_4_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="causal" + ), + "softmax_4_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="off-by-one", + ), + "softmax_4_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="causal", + softmax_type="learnable", + ), + "softmax_5_0": ModelConfig( + 2, 2048, 64, 64, num_gqa_groups=8, window_size=(128, 0), attn_mask_type="padding_causal" + ), + "softmax_5_1": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="off-by-one", + ), + "softmax_5_2": ModelConfig( + 2, + 2048, + 64, + 64, + num_gqa_groups=8, + window_size=(128, 0), + attn_mask_type="padding_causal", + softmax_type="learnable", + ), +} + + +@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("model_configs", [model_configs_softmax]) +@pytest.mark.parametrize("model", model_configs_softmax.keys()) +def test_dpa_softmax(dtype, model_configs, model): + """Test DotProductAttention module with different softmax types""" + test_dot_product_attention( + dtype, model_configs, model, True, True, "bshd_bshd_bshd", False, False + ) + + model_configs_mla = 
{ - # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 - "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 - "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 + # test: ModelConfig(b, sq, hq, dqk) + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), "mla_2_1": ModelConfig( 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 - ), # cross, 1 + ), "mla_2_2": ModelConfig( 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 - ), # cross, 1 - "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference - "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference - "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference + ), + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), } @@ -399,7 +529,7 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) 
"mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), @@ -454,18 +584,16 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), - "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped - "bias_1_5": ModelConfig( - 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" - ), # skipped + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), + "bias_1_5": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi"), "bias_2_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_1": ModelConfig( 2, 128, @@ -474,10 +602,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_2_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_2_3": ModelConfig( 2, 2048, @@ -486,13 +614,11 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="post_scale_bias", - ), # skipped - "bias_2_4": ModelConfig( - 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), + "bias_2_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi"), "bias_2_5": ModelConfig( 2, 2048, 24, 128, 
max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" - ), # skipped + ), "bias_3_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), @@ -510,14 +636,14 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), "bias_3_5": ModelConfig( 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_0": ModelConfig( 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_1": ModelConfig( 2, 128, @@ -526,10 +652,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=256, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_2": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" - ), # skipped + ), "bias_4_3": ModelConfig( 2, 2048, @@ -538,10 +664,10 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias", - ), # skipped + ), "bias_4_4": ModelConfig( 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" - ), # skipped + ), "bias_4_5": ModelConfig( 2, 2048, @@ -550,7 +676,7 @@ def test_dpa_mask(dtype, model_configs, model): max_seqlen_kv=4096, attn_mask_type="padding_causal", attn_bias_type="alibi", - ), # skipped + ), } @@ -564,7 +690,7 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { - # test: b, h, hg, d, sq, skv, p, + # test: ModelConfig(b, sq, hq, dqk) "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), "bias_1_2": ModelConfig(4, 2048, 24, 128, 
attn_bias_type="post_scale_bias", bias_shape="b1ss"), @@ -602,7 +728,7 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "swa_1_1": ModelConfig(2, 2048, 16, 64), "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), @@ -642,7 +768,7 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { - # test: b, h, hg, d, sq, skv, p, mask, bias, alibi_type + # test: ModelConfig(b, sq, hq, dqk) "alibi_1_0": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" ), @@ -696,7 +822,7 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 128, 16, 64), "layout_0_1": ModelConfig( 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -744,7 +870,7 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), @@ -882,7 +1008,6 @@ def _run_dot_product_attention( share_cu_seqlens_ref: bool = False, ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]: """Run DotProductAttention module with one forward pass and one backward pass""" - # Set RNG and environment varables reset_rng_states() os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1037,6 +1162,8 @@ def _run_dot_product_attention( layout = layout.replace("d", "dqk") 
tensor_shape = [dim_to_num[j] for j in layout.split("_")] tensor = 0.1 * torch.randn(tensor_shape, dtype=dtype, device="cuda") + # tensor: with padding tokens + # tensor_orig: without padding tokens tensor_orig = tensor if qkv_format == "thd" and pad_between_seqs: tensor_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1145,9 +1272,13 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tp_group=None, layer_number=1, attention_type=config.attn_type, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, ).to(dtype=dtype, device="cuda") if not is_training: block = block.eval() + if is_training and config.softmax_type != "vanilla": + block.softmax_offset.requires_grad = True cu_seqlens_q_padded = None cu_seqlens_kv_padded = None @@ -1186,14 +1317,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: alibi_slopes=alibi_slopes, fast_zero_fill=True, ) + max_logit = None + if config.return_max_logit: + out, max_logit = out if is_training: out.backward(d_out) + d_softmax_offset = None + if is_training and config.softmax_type != "vanilla": + d_softmax_offset = block.softmax_offset.grad + if backend in ["FlashAttention", "UnfusedDotProductAttention"]: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, max_logit, (None, None, None, d_softmax_offset) if backend == "FusedAttention": if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) @@ -1222,18 +1360,22 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: [v_grad_orig, v.grad[valid_range_kv[0] : valid_range_kv[1]]], dim=0 ) if is_training: - return out_orig, (q_grad_orig, k_grad_orig, v_grad_orig) + return ( + out_orig, + max_logit, + (q_grad_orig, k_grad_orig, v_grad_orig, d_softmax_offset), + ) else: - return out_orig, (None, None, None) + return out_orig, max_logit, (None, None, None, 
d_softmax_offset) else: if is_training: - return out, (q.grad, k.grad, v.grad) + return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: - return out, (None, None, None) + return out, max_logit, (None, None, None, d_softmax_offset) model_configs_te_layer = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), "te_1_1": ModelConfig( 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" @@ -1598,16 +1740,17 @@ def _run_transformer_layer( model_configs_fp8_extra_state = { + # test: ModelConfig(b, sq, hq, dqk) "large": ModelConfig(2, 128, 4, 128, num_layers=1), } -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") @pytest.mark.parametrize("model", ["large"]) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): +def test_dpa_fp8_extra_state(model, dtype): + """Test DotProductAttention module in FP8 with checkpointing""" config = model_configs_fp8_extra_state[model] # Test backend availability is_training = True @@ -1621,9 +1764,9 @@ def test_sanity_attention_extra_state(model, dtype): if not fused_attn_supported and not flash_attn_supported: pytest.skip("No attention backend available.") - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = _run_attention_extra_state( + outputs = _run_dpa_fp8_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_dpa_fp8_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = 
_run_dpa_fp8_extra_state( dtype, config, mimic_v1_6=True, checkpoint=True ) @@ -1645,7 +1788,8 @@ def test_sanity_attention_extra_state(model, dtype): ) -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): +def _run_dpa_fp8_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + """Run DotProductAttention module in FP8 with checkpointing""" steps = 10 path = "checkpoint.pt" fp8_enabled = True @@ -1671,7 +1815,7 @@ def get_model(dtype, config): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): + with quantized_model_init(enabled=fp8_enabled, recipe=fp8_recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -1688,7 +1832,7 @@ def get_model(dtype, config): block = get_model(dtype, config) for i in range(steps // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_enabled, recipe=fp8_recipe): output = block(hidden_states, None) loss = output.sum() loss.backward() @@ -1723,7 +1867,7 @@ def get_model(dtype, config): assert not param_grads, "Oops!" 
for i in range((steps + 1) // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_enabled, recipe=fp8_recipe): output = block(hidden_states, None) loss = output.sum() loss.backward() @@ -1742,7 +1886,7 @@ def get_model(dtype, config): model_configs_fp8_vs_f16 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_9": ModelConfig(2, 2048, 16, 128), "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), @@ -1762,37 +1906,9 @@ def get_model(dtype, config): qkv_format_fp8_vs_f16 = ["bshd", "sbhd"] -def _rmse(a, b): - return math.sqrt((torch.pow((a - b), 2) / a.numel()).sum()) - - -def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): - logging.debug(name_a + " min {:.6f} max {:.6f}".format(a.min().item(), a.max().item())) - logging.debug(name_b + " min {:.6f} max {:.6f}".format(b.min().item(), b.max().item())) - try: - if a.dtype != b.dtype: - a = a.to(b.dtype) - torch.testing.assert_close(a, b, atol=atol, rtol=rtol) - except Exception as e: - logging.debug(e) - - rmse = _rmse(a, b) - logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) - rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) - assert rmse < rmse_tol * rmse_range, ( - name_a - + " vs " - + name_b - + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( - rmse, rmse_tol * rmse_range, rmse_tol, rmse_range - ) - ) - - @pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused attention is not supported on ROCm") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) 
@pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("qkv_format", qkv_format_fp8_vs_f16) @@ -1800,22 +1916,44 @@ def _error(a, b, name_a, name_b, atol, rtol, rmse_tol): @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("RoPE", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_mha_fp8_vs_f16( + dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, RoPE, is_training, scaling_mode +): + """Test MultiHeadAttention module in FP8""" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + fp8_mha=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + fp8_mha=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_format.replace("hd", "h3d"), + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # Skip if only unfused backend is supported - if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: - pytest.skip("Less than two backends to compare.") + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: available_backends, _, fused_attn_backends = get_available_attention_backends( config, 
@@ -1833,7 +1971,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") flash_attn_fwd_fp8, param_names, flash_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" @@ -1841,20 +1979,21 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, _attention_backends["backend_selection_requires_update"] = True logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = True") fused_attn_fwd_fp8, param_names, fused_attn_bwd_fp8 = _run_mha_fp8_vs_f16( - dtype, config, True, qkv_format, input_layernorm, RoPE, is_training + dtype, config, True, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) logging.info("[test_mha_fp8_vs_f16]: run with fp8_mha = False") fused_attn_fwd_f16, param_names, fused_attn_bwd_f16 = _run_mha_fp8_vs_f16( - dtype, config, False, qkv_format, input_layernorm, RoPE, is_training + dtype, config, False, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-1 rmse_tol = 0.15 - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -1862,8 +2001,11 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) - _error( + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( fused_attn_fwd_fp8, 
fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -1871,12 +2013,13 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) if is_training: for i in range(len(param_names[:1])): logging.debug("========== {:^25s} ==========".format(param_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -1884,10 +2027,14 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, atol, rtol, rmse_tol, + True, ) -def _run_mha_fp8_vs_f16(dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training): +def _run_mha_fp8_vs_f16( + dtype, config, fp8_mha, qkv_format, input_layernorm, RoPE, is_training, fp8_recipe +): + """Run MultiHeadAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -1896,16 +2043,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_mha, - fp8_mha=fp8_mha, - ) - - with fp8_model_init(enabled=fp8_mha, recipe=fp8_recipe): + with quantized_model_init(enabled=fp8_mha, recipe=fp8_recipe): rotary_pos_emb = None if RoPE: PE = RotaryPositionEmbedding(dim=config.head_dim_qk) @@ -1977,7 +2115,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: tensor = 0.01 * torch.randn(tensor_shape, dtype=dtype, device="cuda") out_grad = tensor.view(*tensor.shape[:-2], -1) - with fp8_autocast(enabled=fp8_mha, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8_mha, recipe=fp8_recipe): out = mha( hidden_states, attn_mask_type=config.attn_mask_type, @@ -2007,14 +2145,15 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: @pytest.mark.skipif(IS_HIP_EXTENSION, reason="FP8 Fused 
attention is not supported on ROCm") @pytest.mark.skipif(get_cudnn_version() < (9, 2, 1), reason="cuDNN 9.2.1+ is required.") -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8_vs_f16) @pytest.mark.parametrize("model", model_configs_fp8_vs_f16.keys()) @pytest.mark.parametrize("qkv_layout", qkv_layout_fp8_vs_f16) @pytest.mark.parametrize("fp8_dpa_bwd", [True, False]) @pytest.mark.parametrize("is_training", [True, False]) -def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): +@pytest.mark.parametrize("scaling_mode", ["delayed", "current"]) +def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training, scaling_mode): + """Test DotProductAttention module in FP8""" config = model_configs_fp8_vs_f16[model] # TODO(cyang): think of another way to verify dropout results @@ -2029,16 +2168,33 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "1" # Test backend availability + if scaling_mode == "delayed": + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=True, + ) + elif scaling_mode == "current": + fp8_recipe = recipe.Float8CurrentScaling( + fp8_format=recipe.Format.HYBRID, + fp8_dpa=True, + ) + fp8_meta = {} + fp8_meta["recipe"] = fp8_recipe available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, qkv_layout=qkv_layout, + fp8=True, + fp8_meta=fp8_meta, is_training=is_training, ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = 
available_backends - # Skip if only unfused backend is supported if flash_attn_supported + fused_attn_supported < 1: pytest.skip("No FP8 attention backend available.") if not fp8_dpa_bwd: @@ -2058,33 +2214,45 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FlashAttention)") flash_attn_fwd_fp8, flash_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe + ) + + if unfused_attn_supported: + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (UnfusedDotProductAttention)") + unfused_attn_fwd_fp8, unfused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( + dtype, config, True, qkv_layout, is_training, fp8_recipe ) os.environ["NVTE_FLASH_ATTN"] = "0" os.environ["NVTE_FUSED_ATTN"] = "1" _attention_backends["backend_selection_requires_update"] = True - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = True (FusedAttention)") fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_dpa_fp8_vs_f16( - dtype, config, True, qkv_layout, is_training + dtype, config, True, qkv_layout, is_training, fp8_recipe ) + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "1" if config.dropout_p == 0.0: # test cuDNN FP8 dropout: need a FP16/BF16 reference on Blackwell - logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False") + logging.info("[test_dpa_fp8_vs_f16]: run with fp8_dpa = False (FusedAttention)") fused_attn_fwd_f16, fused_attn_bwd_f16 = _run_dpa_fp8_vs_f16( - dtype, config, False, 
qkv_layout, is_training + dtype, config, False, qkv_layout, is_training, fp8_recipe ) atol = 5e-1 rtol = 5e-2 rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] - logging.debug("========== {:^25s} ==========".format("forward output")) if flash_attn_supported: - _error( + logging.debug("========== {:^25s} ==========".format("flash fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( flash_attn_fwd_fp8, fused_attn_fwd_f16, "flash_attn_fwd_fp8", @@ -2092,14 +2260,43 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, + ) + if unfused_attn_supported: + logging.debug("========== {:^25s} ==========".format("unfused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( + unfused_attn_fwd_fp8, + fused_attn_fwd_f16, + "unfused_attn_fwd_fp8", + "fused_attn_fwd_f16", + atol, + rtol, + rmse_tol, + True, ) + if is_training: + for i, _ in enumerate(fused_attn_bwd_f16): + logging.debug("========== {:^25s} ==========".format(bwd_names[i])) + compare_and_assert( + unfused_attn_bwd_fp8[i], + fused_attn_bwd_f16[i], + f"unfused_attn_bwd_fp8[{i}]", + f"fused_attn_bwd_f16[{i}]", + atol, + rtol, + rmse_tol, + True, + ) if config.dropout_p != 0.0: # test cuDNN FP8 dropout assert torch.all( fused_attn_fwd_fp8 == 1 ), "fused_attn_fwd_fp8 must be all 1s when Q/K/V are all 1s." 
else: - _error( + logging.debug("========== {:^25s} ==========".format("fused fp8 vs fused f16:")) + logging.debug("========== {:^25s} ==========".format("forward output")) + compare_and_assert( fused_attn_fwd_fp8, fused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -2107,11 +2304,12 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) if is_training: for i, _ in enumerate(fused_attn_bwd_f16): logging.debug("========== {:^25s} ==========".format(bwd_names[i])) - _error( + compare_and_assert( fused_attn_bwd_fp8[i], fused_attn_bwd_f16[i], f"fused_attn_bwd_fp8[{i}]", @@ -2119,11 +2317,13 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): atol, rtol, rmse_tol, + True, ) + os.environ["NVTE_UnfusedDPA_Emulate_FP8"] = "0" -def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training): - +def _run_dpa_fp8_vs_f16(dtype, config, fp8_dpa, qkv_layout, is_training, fp8_recipe): + """Run DotProductAttention module in FP8""" reset_rng_states() _DUMMY_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker() _DUMMY_CUDA_RNG_STATE_TRACKER.add("model-parallel-rng", seed) @@ -2132,16 +2332,8 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: """Get cuda rng tracker.""" return _DUMMY_CUDA_RNG_STATE_TRACKER - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_dpa, - ) - qkv_format = "".join([i for i in qkv_layout.split("_")[0] if i.isalpha()]) - with fp8_model_init(enabled=fp8_dpa): + with quantized_model_init(enabled=fp8_dpa): dpa = DotProductAttention( config.num_heads, config.head_dim_qk, @@ -2233,7 +2425,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: out_grad_shape_new = [*out_grad_shape[:-2], out_grad_shape[-2] * out_grad_shape[-1]] out_grad = torch.randn(out_grad_shape_new, dtype=dtype, device="cuda") - with fp8_autocast(enabled=fp8_dpa, fp8_recipe=fp8_recipe): + with 
autocast(enabled=fp8_dpa, recipe=fp8_recipe): out = dpa( inp[0], inp[1], @@ -2246,6 +2438,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: attn_mask_type=config.attn_mask_type, checkpoint_core_attention=False, core_attention_bias_type=config.attn_bias_type, + fp8_output=fp8_dpa, ) if is_training: out.backward(out_grad) @@ -2256,7 +2449,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "fp8_1": ModelConfig(1, 512, 1, 64), "fp8_2": ModelConfig(4, 512, 16, 64), "fp8_3": ModelConfig(1, 2048, 1, 128), @@ -2281,8 +2474,7 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: ), reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""", ) -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper+.") +@pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8) @pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) def test_custom_mha_fp8_vs_f16(dtype, model): @@ -2312,7 +2504,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol = 5e-1 rtol = 5e-1 rmse_tol = 0.13 - _error( + compare_and_assert( fused_attn_fwd_fp8, unfused_attn_fwd_f16, "fused_attn_fwd_fp8", @@ -2320,8 +2512,9 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) - _error( + compare_and_assert( fused_attn_bwd_fp8, unfused_attn_bwd_f16, "fused_attn_bwd_fp8", @@ -2329,6 +2522,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): atol, rtol, rmse_tol, + True, ) @@ -2372,7 +2566,7 @@ def _run_custom_mha_fp8(dtype, config, backend): ) mha = Custom_MHA_FP8(config).to(dtype=dtype, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with autocast(enabled=True, recipe=fp8_recipe): out = mha(inp, 
cu_seqlens, config.max_seqlen_q) out.backward(out_grad) @@ -2570,7 +2764,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) proj_dgrad = ctx.dO_quantizer(grad_output) - fp8_dtype_backward = fp8.get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) + fp8_dtype_backward = get_fp8_te_dtype(ctx.fp8_meta["recipe"], fprop_tensor=False) dq, dk, dv, *rest = fused_attn_bwd( ctx.max_s, diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index ece5a37de..d0956c226 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -8,27 +8,39 @@ import subprocess import sys import pathlib +import logging import pytest import torch + from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.utils import ( + +from transformer_engine.pytorch import ( get_device_compute_capability, get_cudnn_version, ) +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) from utils import ModelConfig, get_available_attention_backends +pytest_logging_level = logging.getLevelName(logging.root.level) + # Initialize RNG state seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) +test_essential = True + model_configs_flash_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias + # test: ModelConfig(b, sq, hq, dqk) "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA @@ -63,18 +75,31 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): 
return args +dtypes = ["bf16", "fp16"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_2_1", "cp_3_2", "cp_3_3"] + model_configs_flash_attn = {k: model_configs_flash_attn[k] for k in configs} + dtypes = ["bf16"] + qkv_formats = ["sbhd", "thd"] + + @pytest.mark.skipif(not FlashAttentionUtils.v2_plus, reason="Flash-attn 2.0+ is required.") @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") config = model_configs_flash_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1): pytest.skip("CP implementation with KV P2P does not support sliding window yet!") if cp_comm_type == "all_gather" and qkv_format == "thd": @@ -95,6 +120,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): if IS_HIP_EXTENSION: if config.head_dim_qk != config.head_dim_v and not FlashAttentionUtils.v3_is_installed: pytest.skip("MLA FlashAttention requires v3+!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16} + available_backends, *_ = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + 
qkv_layout="_".join([qkv_format] * 3), + ) + flash_attn_supported, *_ = available_backends + if not flash_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( @@ -104,14 +138,15 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + log_level=pytest_logging_level, ), check=True, ) model_configs_fused_attn = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA - "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + # test: ModelConfig(b, sq, hq, dqk) + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", return_max_logit=not IS_HIP_EXTENSION), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128, return_max_logit=not IS_HIP_EXTENSION), # MHA "cp_1_2": ModelConfig( 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA @@ -140,17 +175,42 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA + "cp_4_0": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="vanilla" + ), # GQA + "cp_4_1": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="off-by-one" + ), # GQA + "cp_4_2": ModelConfig( + 2, 4096, 64, 64, num_gqa_groups=8, attn_mask_type="causal", softmax_type="learnable" + ), # GQA } +dtypes = ["bf16", "fp16", "fp8"] +qkv_formats = ["bshd", "sbhd", "thd"] +cp_comm_types = ["p2p", "all_gather", "a2a", "a2a+p2p"] +if test_essential: + configs = ["cp_1_0", "cp_1_1", "cp_2_0", "cp_2_2", "cp_3_2", "cp_4_2"] + model_configs_fused_attn = {k: model_configs_fused_attn[k] for k in configs} + dtypes = ["bf16", "fp8"] + qkv_formats = ["sbhd", 
"thd"] + + @pytest.mark.skipif(get_cudnn_version() < (8, 9, 7), reason="cuDNN 8.9.7+ is required.") @pytest.mark.skipif(not IS_HIP_EXTENSION and get_device_compute_capability() < (8, 0), reason="CP tests require sm80+.") -@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"]) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("model", model_configs_fused_attn.keys()) -@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"]) -@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"]) -@pytest.mark.parametrize("fp8_mha", [False, True]) -def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha): +@pytest.mark.parametrize("qkv_format", qkv_formats) +@pytest.mark.parametrize("cp_comm_type", cp_comm_types) +@pytest.mark.parametrize("fp8_bwd", [True, False]) +@pytest.mark.parametrize("fp8_mha", [True, False]) +@pytest.mark.parametrize("fp8_dpa", [True, False]) +@pytest.mark.parametrize("scaling_mode", [None, "delayed", "current"]) +@pytest.mark.parametrize("f16_O", [True, False]) +def test_cp_with_fused_attention( + dtype, model, qkv_format, cp_comm_type, fp8_bwd, fp8_mha, fp8_dpa, scaling_mode, f16_O +): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 if num_gpus > torch.cuda.device_count(): pytest.skip(f"Test requires {num_gpus} GPUs, but found {torch.cuda.device_count()}") @@ -162,9 +222,16 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha if (not IS_HIP_EXTENSION) and dtype == "fp8" and get_device_compute_capability() < (9, 0): pytest.skip("FP8 attention is only supported on sm90+!") if IS_HIP_EXTENSION and dtype == "fp8": - pytest.skip("FP8 attention has not been supported on ROCm yet!") + pytest.skip("FP8 attention is not supported on ROCm yet!") + if dtype == "fp8" and not fp8_dpa and fp8_mha: + pytest.skip("Duplicate tests to fp8_dpa=True and fp8_mha=True!") + if dtype != "fp8" and fp8_bwd: + pytest.skip("Only fp8 works with fp8_bwd=True!") config = 
model_configs_fused_attn[model] + config.context_parallel = True + config.cp_comm_type = cp_comm_type + if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias": pytest.skip("THD format does not support post_scale_bias yet!") if qkv_format == "thd" and cp_comm_type == "all_gather": @@ -192,19 +259,57 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" ) - if dtype != "fp8" and fp8_mha: - pytest.skip("Only fp8 works with fp8_mha=True!") + if dtype != "fp8" and (fp8_mha or fp8_dpa): + pytest.skip("Only fp8 works with fp8_dpa=True or fp8_mha=True!") + if dtype == "fp8" and not (fp8_mha or fp8_dpa): + pytest.skip("fp8 only works with fp8_dpa=True or fp8_mha=True!") + if dtype != "fp8" and scaling_mode is not None: + pytest.skip("Only fp8 works with scaling_mode != None!") + if dtype == "fp8" and scaling_mode is None: + pytest.skip("fp8 only works with scaling_mode != None!") + if ( + dtype == "fp8" + and scaling_mode == "current" + and cp_comm_type not in ["p2p", "a2a+p2p", "a2a"] + ): + pytest.skip("fp8 only works with P2P, A2A and A2A+P2P for scaling_mode = current!") + if f16_O and (dtype != "fp8" or scaling_mode != "current"): + pytest.skip("f16_O only needs to be tested for dtype = fp8 and scaling_mode = current!") if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + if dtype == "fp8" and config.softmax_type != "vanilla": + pytest.skip("CP implementation does not support non-vanilla softmax types in FP8!") + if config.softmax_type != "vanilla" and cp_comm_type != "a2a": + pytest.skip( + "CP implementation only supports cp_comm_type=a2a for non-vanilla softmax types!" 
+ ) + if config.softmax_type != "vanilla" and qkv_format == "thd": + pytest.skip( + "CP implementation does not support qkv_format=thd for non-vanilla softmax types!" + ) + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + fp8_meta = {} + fp8_meta["recipe"] = None + fp8_meta["local_recipes"] = [] + fp8 = dtype == "fp8" and (fp8_dpa or fp8_mha) + if fp8 and scaling_mode == "delayed": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [DelayedScaling(fp8_dpa=True)] + if fp8 and scaling_mode == "current": + fp8_meta["recipe"] = DelayedScaling(fp8_dpa=True) + fp8_meta["local_recipes"] = [ + Float8CurrentScaling(fp8_dpa=True), + DelayedScaling(fp8_dpa=True), + ] available_backends, _, fused_attn_backends = get_available_attention_backends( config, - qkv_dtype=dtypes[dtype], + qkv_dtype=dtypes[dtype] if dtype != "fp8" else get_torch_float8_e4m3_type(), qkv_layout="_".join([qkv_format] * 3), - window_size=config.window_size, - context_parallel=True, + fp8=fp8, + fp8_meta=fp8_meta, ) _, fused_attn_supported, _ = available_backends if not fused_attn_supported: @@ -218,7 +323,12 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha qkv_format=qkv_format, kernel_backend="FusedAttention", cp_comm_type=cp_comm_type, + fp8_bwd=fp8_bwd, + fp8_dpa=fp8_dpa, fp8_mha=fp8_mha, + scaling_mode=scaling_mode, + f16_O=f16_O, + log_level=pytest_logging_level, ), check=True, ) diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py index 00200c62d..0dd5ba601 100644 --- a/tests/pytorch/attention/test_cp_utils.py +++ b/tests/pytorch/attention/test_cp_utils.py @@ -5,7 +5,6 @@ """Unit tests for context parallel utils.""" import torch import unittest -from typing import Tuple from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( get_batch_on_this_cp_rank, pad_thd_sequences_for_cp, diff --git 
a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index af71866f3..eb86c0776 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -15,22 +15,25 @@ import pytest import torch -from torch.distributions import Exponential from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch import make_graphed_callables -from transformer_engine.common import recipe -from transformer_engine.pytorch import fp8_autocast, fp8_model_init -from transformer_engine.pytorch.transformer import ( + +from torch.distributions import Exponential +from transformer_engine.pytorch import ( + make_graphed_callables, + autocast, + quantized_model_init, TransformerLayer, + DotProductAttention, + InferenceParams, + is_bf16_available, ) -from transformer_engine.pytorch.attention import DotProductAttention, InferenceParams +from transformer_engine.common import recipe from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, ) from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, - is_bf16_compatible, ) _current_file = pathlib.Path(__file__).resolve() @@ -45,7 +48,7 @@ reset_rng_states() param_types = [torch.float16] -if is_bf16_compatible(): +if is_bf16_available(): param_types.append(torch.bfloat16) model_configs_infer = { @@ -241,7 +244,7 @@ def get_model( if module == "TransformerLayer": hidden_size = config.head_dim_qk * config.num_heads - with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): + with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe): model = [ TransformerLayer( hidden_size=hidden_size, @@ -264,7 +267,7 @@ def get_model( for layer_number in range(1, num_layers + 1) ] if module == "DotProductAttention": - with fp8_model_init(enabled=is_fp8, recipe=fp8_recipe): + with quantized_model_init(enabled=is_fp8, recipe=fp8_recipe): model = [ DotProductAttention( 
kv_channels=config.head_dim_qk, @@ -483,7 +486,6 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g config, qkv_dtype=dtype, qkv_layout=qkv_layout, - window_size=config.window_size, pad_between_seqs=False, is_training=False, fp8=is_fp8, @@ -574,9 +576,9 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g model[i], sample_args, num_warmup_iters=10, - fp8_enabled=is_fp8, + enabled=is_fp8, sample_kwargs=sample_kwargs, - fp8_recipe=fp8_recipe, + recipe=fp8_recipe, ) for i in range(num_layers) ] @@ -669,7 +671,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g if inference_params.is_paged: inference_params.cache_manager.print_cache() incremental_output = incremental_inputs - with fp8_autocast(enabled=is_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=is_fp8, recipe=fp8_recipe): for m in model: incremental_output = m( *incremental_output, diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 716c16056..358841943 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -16,7 +16,7 @@ import transformer_engine_torch as tex import nvdlfw_inspect.api as debug_api from transformer_engine.debug import set_weight_tensor_tp_group_reduce -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import is_fp8_available from test_numerics import ( _emulate_linear, @@ -45,7 +45,7 @@ all_boolean = [True, False] TEST_NR = 0 -fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +fp8_available = is_fp8_available() def _get_tensors(parallel_mode, weight_seed=SEED, data_seed=SEED, tp_size=None, tp_rank=None): @@ -117,7 +117,7 @@ def backward(ctx, grad_output): def _run_forward_backward(x, model, parallel_mode=None, group=None): - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + with 
transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE): y = model(x) y.requires_grad_(True) @@ -413,13 +413,13 @@ def test_log_expert_parallel(**kwargs): ) # data parallel model = _init_model(weight, parallel_mode=None, name="linear1") model1 = _init_model(weight, parallel_mode=None, name="linear2") - with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE): y1 = model(x) y2 = model1(x) y = y1 + y2 y.sum().backward() debug_api.step() - with transformer_engine.pytorch.fp8_autocast(enabled=fp8_available, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=fp8_available, recipe=FP8_RECIPE): y = model(x) if WORLD_RANK != 0: y = y + model1(x) @@ -532,7 +532,7 @@ def test_per_tensor_scaling( LOSS_MULTIPLIER = 100 - with transformer_engine.pytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + with transformer_engine.pytorch.autocast(enabled=True, recipe=FP8_RECIPE): y = model(x) model.zero_grad() if parallel_mode == "column": @@ -668,11 +668,12 @@ def _run_test_with_combinations( _init_distributed() test_log_expert_parallel() - for parallel_mode in ["column", "row"]: - for gather_weight in [True, False]: - test_log_distributed(parallel_mode, gather_weight) if fp8_available: + for parallel_mode in ["column", "row"]: + for gather_weight in [True, False]: + test_log_distributed(parallel_mode, gather_weight) + for parallel_mode in ["row", "column"]: test_disable_fp8_layer(parallel_mode) diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index d28db1647..fbf619d48 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -3,7 +3,7 @@ # See LICENSE for license information. 
import torch -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer +from transformer_engine.pytorch import Float8Tensor, Float8Quantizer import nvdlfw_inspect.api as debug_api diff --git a/tests/pytorch/debug/test_config.py b/tests/pytorch/debug/test_config.py index 71715a686..9b6bcd1cd 100644 --- a/tests/pytorch/debug/test_config.py +++ b/tests/pytorch/debug/test_config.py @@ -1,7 +1,7 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -import pathlib, os +import pathlib from nvdlfw_inspect.config_manager import ConfigManager diff --git a/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml b/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml new file mode 100644 index 000000000..224be4618 --- /dev/null +++ b/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml @@ -0,0 +1,11 @@ +test_switch_to_nondebug_mode: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + TestDummyFeature: + enabled: True + inspect_only_once: True + tensors: [weight, activation, gradient, output, wgrad, dgrad] + gemms: [wgrad, dgrad, fprop] + diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index dcc9861c8..0f833d41f 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -8,18 +8,26 @@ import torch import tempfile from transformer_engine.common import recipe -from transformer_engine.pytorch.fp8 import RecipeState import pytest import contextlib import os -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import ( + is_fp8_available, + is_mxfp8_available, + is_fp8_block_scaling_available, +) +from transformer_engine.pytorch.quantization import RecipeState from transformer_engine.debug.pytorch.debug_state import TEDebugState +from transformer_engine.debug.features.utils.stats_computation import ( + 
compute_max_blockwise_dynamic_range, + BlockwiseDynamicRangeStat, +) +import math - -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( + return_reason=True ) LOG_QUANTIZED_CONFIG_BASE = """ @@ -128,7 +136,7 @@ def test_sanity(feature_dirs): inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda() for _ in range(10): - with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): + with te.autocast(recipe=recipe.DelayedScaling()): output = model(inp) loss = output.sum() loss.backward() @@ -150,7 +158,7 @@ def test_sanity(feature_dirs): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -def test_numerics(fp8_recipe, feature_dirs): +def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): if not fp8_available: pytest.skip(reason_for_no_fp8) if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling(): @@ -206,6 +214,107 @@ def test_numerics(fp8_recipe, feature_dirs): assert overflows == pytest.approx(expected.cpu(), abs=1e-4) +LOG_HIGH_PRECISION_CONFIG = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogTensorStats: + enabled: True + stats: + - dynamic_range + - max_blockwise_dynamic_range: + block_size: 4 + dims: 1 + - max_blockwise_dynamic_range: + block_size: 4 + dims: 2 + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" + + +@pytest.mark.parametrize("tensor_name", ["activation", "weight", "gradient"]) +def test_log_stats_numerics(feature_dirs, tensor_name): + """Check correctness of dynamic range 
and max blockwise dynamic range stats. + + Tests different tensor types: + - activation/weight: use both orientations (rowwise + columnwise), takes max + - gradient/dgrad: use single orientation (rowwise only) + """ + log_only_bare_stats_config = LOG_HIGH_PRECISION_CONFIG + + with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: + # There is 1024 x 1024 tensor with very small epsilon values in almost all elements, + # one row of large value A and three rows of large value B. + epsilon = 1e-10 + A = 1000 + B = 50 + tensor = torch.zeros(1024, 1024).cuda() + epsilon + tensor[0, :] = A + tensor[1:4, :] = B + + debug_api.transformer_engine.inspect_tensor( + layer_name="layer_name", + tensor_name=tensor_name, + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=None, + rowwise_quantized_tensor=None, + columnwise_quantized_tensor=None, + ) + debug_api.step() + + output = read_log(log_dir) + + max_over_orientations = tensor_name in ["activation", "weight"] + max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else "" + + # Track which stats were found to ensure all are present + found_dims_1 = False + found_dims_2 = False + found_dynamic_range = False + + for line in output.splitlines(): + if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line: + max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) + if max_over_orientations: + # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) + expected = math.log2(A) - math.log2(B) + else: + # Rowwise blocks have uniform values -> dynamic_range = 0 + expected = 0 + assert max_blockwise_dynamic_range_block_size_4_dims_1 == pytest.approx( + expected, abs=1e-4 + ) + found_dims_1 = True + elif ( + f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line + ): + max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) + # For 2D blocks (4x4 
tiles), blocks always contain mixed values from different rows + expected = math.log2(A) - math.log2(B) + assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx( + expected, abs=1e-4 + ) + found_dims_2 = True + elif "_dynamic_range" in line and "max_blockwise_dynamic_range" not in line: + dynamic_range = float(line.split("value=")[1]) + expected = math.log2(A) - math.log2(epsilon) + assert dynamic_range == pytest.approx(expected, abs=1e-4) + found_dynamic_range = True + + # Ensure all expected stats were found in the output + assert found_dims_1, "max_blockwise_dynamic_range (dims=1) not found in output" + assert found_dims_2, "max_blockwise_dynamic_range (dims=2) not found in output" + assert found_dynamic_range, "dynamic_range not found in output" + + @pytest.mark.parametrize("layer", ["linear", "transformer"]) def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): if not fp8_available: @@ -232,7 +341,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): for i in range(20): x = torch.randn(4, 128, 128).cuda() - with te.fp8_autocast(enabled=True): + with te.autocast(enabled=True): y = model(x) y.sum().backward() debug_api.step() @@ -252,3 +361,92 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): debug_api.end_debug() TEDebugState._reset() + + +def test_compute_max_blockwise_dynamic_range_direct(): + """Direct unit test for compute_max_blockwise_dynamic_range function. + + Tests the function with various configurations to ensure correct behavior + for different block sizes, dimensions, and orientation settings. 
+ """ + # Create test tensor with uniform rows but mixed columns + # Row 0: all 1000, Row 1-3: all 50, remaining: all 0.01 + epsilon = 0.01 + A = 1000.0 + B = 50.0 + tensor = torch.zeros(1024, 1024).cuda() + epsilon + tensor[0, :] = A + tensor[1:4, :] = B + + # Test 1: dims=1, max_over_orientations=False (rowwise only) + # Rowwise blocks have uniform values -> dynamic_range should be 0 + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + assert result.item() == pytest.approx( + 0.0, abs=1e-4 + ), "Rowwise 1D blocks with uniform values should have dynamic_range=0" + + # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise) + # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + expected = math.log2(A) - math.log2(B) + assert result.item() == pytest.approx(expected, abs=1e-4), ( + f"Max over orientations should capture columnwise dynamic_range, expected {expected}, got" + f" {result.item()}" + ) + + # Test 3: dims=2, block_size=4 (4x4 tiles) + # 2D blocks span multiple rows -> always have mixed values + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False) + result = compute_max_blockwise_dynamic_range(tensor, stat_config) + expected = math.log2(A) - math.log2(B) + assert result.item() == pytest.approx(expected, abs=1e-4), ( + f"2D blocks should capture mixed values from different rows, expected {expected}, got" + f" {result.item()}" + ) + + # Test 4: Different block size + # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon] + # So max=A, min=epsilon (not B anymore) + stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True) + result = 
compute_max_blockwise_dynamic_range(tensor, stat_config) + expected = math.log2(A) - math.log2(epsilon) # min is epsilon, not B + assert result.item() == pytest.approx( + expected, abs=1e-4 + ), f"Block size 8 should work correctly, expected {expected}, got {result.item()}" + + # Test 5: Tensor with all uniform values -> dynamic_range should be 0 + uniform_tensor = torch.ones(64, 64).cuda() * 42.0 + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) + result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config) + assert result.item() == pytest.approx( + 0.0, abs=1e-4 + ), "Uniform tensor should have dynamic_range=0" + + # Test 6: 3D tensor flattening validation using 2D/3D comparison + # Create a 4x4 tensor with distinct 2x2 blocks, compute with dims=2, block_size=2 + # Then reshape to 3D and compute again - results should match if flattening is correct + tensor_2d = torch.tensor( + [ + [1.0, 1.0, 10.0, 10.0], + [1.0, 1.0, 10.0, 10.0], + [100.0, 100.0, 1000.0, 1000.0], + [100.0, 100.0, 1000.0, 1000.0], + ] + ).cuda() + + # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100) + stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False) + result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config) + + # Reshape to 3D [2, 2, 4] and compute - should give same result if flattening is correct + tensor_3d = tensor_2d.reshape(2, 2, 4) + result_3d = compute_max_blockwise_dynamic_range(tensor_3d, stat_config) + + assert result_2d.item() == pytest.approx(result_3d.item(), abs=1e-6), ( + "3D tensor [2,2,4] flattened to [4,4] must give same result as original 2D, got" + f" 2D={result_2d.item()}, 3D={result_3d.item()}" + ) + + print("All direct tests for compute_max_blockwise_dynamic_range passed!") diff --git a/tests/pytorch/debug/test_numerics.py b/tests/pytorch/debug/test_numerics.py index 749fa16bc..2ad2c8fb8 100644 --- a/tests/pytorch/debug/test_numerics.py +++ 
b/tests/pytorch/debug/test_numerics.py @@ -17,19 +17,19 @@ import transformer_engine.pytorch as tepytorch import transformer_engine_torch as tex from transformer_engine.common.recipe import DelayedScaling, Format -from transformer_engine.pytorch.fp8 import _default_sf_compute -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch.quantization import _default_sf_compute +from transformer_engine.pytorch import ( Float8Quantizer, Float8CurrentScalingQuantizer, + is_fp8_available, ) from transformer_engine.pytorch.module.base import ( _2X_ACC_DGRAD, _2X_ACC_FPROP, _2X_ACC_WGRAD, ) -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) all_boolean = [True, False] FP8_FORMAT = Format.HYBRID @@ -250,7 +250,7 @@ def _init_model(weight): def _run_forward_backward(x, model, loss_scale=1.0, is_first_microbatch=None, fp8=True): - with tepytorch.fp8_autocast(enabled=fp8, fp8_recipe=FP8_RECIPE): + with tepytorch.autocast(enabled=fp8, recipe=FP8_RECIPE): y = model(x, is_first_microbatch=is_first_microbatch) (y.sum() * loss_scale).backward() debug_api.step() @@ -547,7 +547,7 @@ def run_per_tensor_scaling( LOSS_MULTIPLIER = 100 - with tepytorch.fp8_autocast(enabled=True, fp8_recipe=FP8_RECIPE): + with tepytorch.autocast(enabled=True, recipe=FP8_RECIPE): y = model(x, is_first_microbatch=True) model.zero_grad() y.retain_grad() diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index 2d4b62b23..c8c9ae3c1 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -6,71 +6,70 @@ import pytest import torch import transformer_engine.pytorch as te -import time import nvdlfw_inspect.api as debug_api from transformer_engine.debug.pytorch.debug_state import TEDebugState -def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, 
feature_dirs): - debug_api.end_debug() - TEDebugState._reset() - if debug_tools_initialized: - # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS. - # So after 1 warm-up iteration, this layers should work in non-debug mode. - debug_api.initialize( - config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs - ) - - try: - if layer == "linear": - model = torch.nn.Sequential( - te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") - ).cuda() - NUM_ITERS = 18000 - elif layer == "transformer": - model = torch.nn.Sequential( - te.TransformerLayer(1, 1, 1, name="transformer1"), - te.TransformerLayer(1, 1, 1, name="transformer2"), - ).cuda() - NUM_ITERS = 2000 +@pytest.mark.parametrize("use_microbatching", [False, True]) +def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching): + """ + Test that layers switch to non-debug mode when no features are active. - x = torch.randn(1, 1, 1).cuda() + Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None). + The TE should: + 1. Call inspect_tensor_enabled to check if feature is needed + 2. Never call inspect_tensor + 3. Allow layers to switch to non-debug mode for optimal performance, + so that inspect_tensor_enabled is never called again. - y = model(x) - y.sum().backward() - debug_api.step() - torch.cuda.synchronize() + Tests both with and without microbatching to ensure proper behavior in both scenarios. 
+ """ - time_start = time.time() - for i in range(NUM_ITERS): - y = model(x) + try: + debug_api.initialize( + config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml", + feature_dirs=feature_dirs, + ) + import transformer_engine.debug.features._test_dummy_feature as dummy_feature + + # Reset counters + dummy_feature._inspect_tensor_enabled_call_count = 0 + dummy_feature._inspect_tensor_call_count = 0 + + model = te.Linear(256, 256, name="test_linear").cuda() + x = torch.randn(8, 256, 256).cuda() + + # Run multiple iterations + for i in range(20): + if use_microbatching: + # Alternate between first and non-first microbatch + is_first_microbatch = i % 2 == 0 + y = model(x, is_first_microbatch=is_first_microbatch) + else: + # Run without specifying is_first_microbatch + y = model(x) y.sum().backward() - if debug_tools_initialized: - debug_api.step() - torch.cuda.synchronize() - time_end = time.time() - - finally: - if debug_tools_initialized: - debug_api.end_debug() - - return time_end - time_start - - -@pytest.mark.parametrize("layer", ["linear", "transformer"]) -def test_cpu_overhead(layer, configs_dir, feature_dirs): - # runs one layer many times on very small tensor - # - gpu time should be negligible, so time should be dominated by cpu time. - # if layers does not invoke any feature in current iteration, - # then it changed into non-debug mode and should not have any non-negligible cpu overhead - # compared to layer without debug tools initialized. 
- - with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs) - without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs) + debug_api.step() + + # Verify inspect_tensor_enabled was called only once per tensor + # (activation, weight, gradient, output, wgrad, dgrad) + enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count + microbatch_info = "with microbatching" if use_microbatching else "without microbatching" + assert enabled_call_count == 6, ( + f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), " + "but should be called 6 times to check if feature is needed for each tensor " + "(activation, weight, gradient, output, wgrad, dgrad)" + ) - print(f"with_debug_tools: {with_debug_tools} s") - print(f"without_debug_tools: {without_debug_tools} s") + # Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None) + inspect_call_count = dummy_feature._inspect_tensor_call_count + assert inspect_call_count == 0, ( + f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), " + "but should never be called when inspect_tensor_enabled returns (False, None)" + ) - assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin + finally: + debug_api.end_debug() + TEDebugState._reset() diff --git a/tests/pytorch/debug/test_sanity.py b/tests/pytorch/debug/test_sanity.py index e4ce35be6..97be3003d 100644 --- a/tests/pytorch/debug/test_sanity.py +++ b/tests/pytorch/debug/test_sanity.py @@ -7,11 +7,10 @@ import nvdlfw_inspect.api as debug_api import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from test_numerics import create_config_file -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) B, S, H, D = 64, 64, 64, 64 @@ -68,7 +67,7 @@ def _get_model(model_key): 
def _run_forward_backward(model, fp8): for _ in range(3): inp = torch.randn((S, B, H)).cuda() - with te.fp8_autocast(enabled=fp8): + with te.autocast(enabled=fp8): out = model(inp) out.sum().backward() debug_api.step() diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py deleted file mode 100644 index 875905c78..000000000 --- a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py +++ /dev/null @@ -1,684 +0,0 @@ -#!/usr/bin/python3 - -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -import argparse -import datetime -import os -import sys - -import torch -from torch import nn -import torch.distributed as dist - -from transformer_engine.common.recipe import ( - DelayedScaling, - Float8CurrentScaling, - Float8BlockScaling, - Format, - Recipe, -) -import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8 -from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Tensor, - Float8CurrentScalingQuantizer, -) -from transformer_engine.pytorch.tensor.utils import replace_raw_data -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor - - -def _get_raw_data(quantized_tensor): - """Get the underlying data of a quantized tensor, used in zero-1 optimizer""" - if isinstance(quantized_tensor, Float8Tensor): - assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute" - assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8" - return quantized_tensor._data - elif isinstance(quantized_tensor, Float8BlockwiseQTensor): - assert hasattr( - quantized_tensor, "_rowwise_data" - ), "Float8BlockwiseQTensor does not have _rowwise_data attribute" - assert ( - quantized_tensor._rowwise_data.dtype == torch.uint8 - ), "Float8BlockwiseQTensor 
_rowwise_data must be uint8" - return quantized_tensor._rowwise_data - else: - raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}") - - -class MiniZero_1: - """A mini zero-1 optimizer implementation, just used for this test""" - - def __init__(self, weights, lr, dp_group): - self.rank = dist.get_rank(dp_group) - self.world_size = dist.get_world_size(dp_group) - - self.weights = weights - self.lr = lr - self.dp_group = dp_group - - # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer - self.offsets = [0] - for weight in self.weights: - self.offsets.append(self.offsets[-1] + weight.numel()) - - # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may - # not be the end range of the last weight. - if self.offsets[-1] % self.world_size != 0: - self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size - - self.master_weights = [] - # The start offset of the master weight in the weight - self.start_offsets = [] - # The overlapping area of the weight and this rank's local buffer - self.overlapping_areas = [] - - # The start and end of this rank's local buffer in the global buffer - rank_start = self.offsets[-1] // self.world_size * self.rank - rank_end = rank_start + self.offsets[-1] // self.world_size - - for weight, offset in zip(self.weights, self.offsets[:-1]): - if offset >= rank_end or (offset + weight.numel()) <= rank_start: - # This weight is not in this rank's local buffer - master_weight = None - start_offset = None - overlapping_area = None - else: - overlapping_start = max(rank_start, offset) - overlapping_end = min(rank_end, offset + weight.numel()) - length = overlapping_end - overlapping_start - start_offset = overlapping_start - offset - if isinstance(weight, QuantizedTensor): - # If weight is a FP8 tensor, we need to use the original high precision version - # to initialize the master weight. 
- high_precision_init_val = weight.get_high_precision_init_val().view(-1) - master_weight = high_precision_init_val.to(weight.device).float()[ - start_offset : start_offset + length - ] - else: - master_weight = ( - weight.detach().view(-1).float()[start_offset : start_offset + length] - ) - overlapping_area = (overlapping_start, overlapping_end) - self.master_weights.append(master_weight) - self.start_offsets.append(start_offset) - self.overlapping_areas.append(overlapping_area) - - # Create global buffer for grads reduce-scatter - self.grad_buffer = torch.empty( - [self.offsets[-1]], dtype=torch.float32, device=weights[0].device - ) - self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end] - - # Create global buffer for weights all-gather - if isinstance(self.weights[0], QuantizedTensor): - weight_buffer_dtype = torch.uint8 - else: - weight_buffer_dtype = weights[0].dtype - self.weight_buffer = torch.empty( - [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device - ) - self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end] - - def step(self): - # ----------------------------------------------------------------------------------------- - # Step 1: Copy grads to the grad buffer - # ----------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - self.grad_buffer[start:end].copy_(weight.main_grad.view(-1)) - - # ----------------------------------------------------------------------------------------- - # Step 2: Grads reduce-scatter - # ----------------------------------------------------------------------------------------- - # Don't use reduce_scatter directly to explicitly control the reduce order. 
- # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG, - # group=self.dp_group) - buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)] - dist.all_gather(buffers, self.grad_buffer, group=self.dp_group) - for i in range(1, self.world_size): - buffers[0] += buffers[i] - rank_start = self.offsets[-1] // self.world_size * self.rank - rank_end = rank_start + self.offsets[-1] // self.world_size - self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end]) - self.grad_buffer_slice /= self.world_size - - # ----------------------------------------------------------------------------------------- - # Step 3: Update master weights - # ----------------------------------------------------------------------------------------- - for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas): - if master_weight is None: - # This weight's master weight is in other rank. - continue - grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]] - master_weight -= grad * self.lr - - # ----------------------------------------------------------------------------------------- - # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight - # ----------------------------------------------------------------------------------------- - if isinstance(self.weights[0], QuantizedTensor): - # FP8 weights case - for i in range(1, len(self.weights)): - assert isinstance(self.weights[i], QuantizedTensor) - cast_master_weights_to_fp8( - self.weights, self.master_weights, self.start_offsets, self.dp_group - ) - else: - # BF16 weights case - for weight, master_weight, start_offset in zip( - self.weights, self.master_weights, self.start_offsets - ): - if master_weight is None: - continue - start = start_offset - end = start_offset + master_weight.numel() - weight.data.view(-1)[start:end].copy_(master_weight) - - # 
----------------------------------------------------------------------------------------- - # Step 5: Copy the updated weights (not all weights) to the weight buffer - # ----------------------------------------------------------------------------------------- - for i in range(len(self.weights)): - master_weight = self.master_weights[i] - if master_weight is None: - continue - start_offset = self.start_offsets[i] - if isinstance(self.weights[i], QuantizedTensor): - weight = _get_raw_data(self.weights[i]) - else: - weight = self.weights[i] - weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()] - overlapping_start, overlapping_end = self.overlapping_areas[i] - self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice) - - # ----------------------------------------------------------------------------------------- - # Step 6: Weight all-gather (FP8 or BF16) - # ----------------------------------------------------------------------------------------- - dist.all_gather_into_tensor( - self.weight_buffer, self.weight_buffer_slice, group=self.dp_group - ) - - # ----------------------------------------------------------------------------------------- - # Step 7: Copy the gathered weights from weight buffer to the actual weights - # ----------------------------------------------------------------------------------------- - for weight, offset in zip(self.weights, self.offsets[:-1]): - start = offset - end = offset + weight.numel() - if isinstance(weight, QuantizedTensor): - weight = _get_raw_data(weight) - weight.view(-1).data.copy_(self.weight_buffer[start:end]) - - -class MiniOptimizer: - - def __init__(self, weights, lr, dp_group): - self.world_size = dist.get_world_size(dp_group) - - self.weights = weights - self.lr = lr - self.dp_group = dp_group - - master_weights = [] - for weight in self.weights: - master_weights.append(weight.detach().float()) - self.master_weights = master_weights - - def step(self): - for weight, 
master_weight in zip(self.weights, self.master_weights): - main_grad = weight.main_grad - - # Don't use all-reduce directly to explicitly control the reduce order. - # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group) - buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)] - dist.all_gather(buffers, main_grad, group=self.dp_group) - for i in range(1, self.world_size): - buffers[0] += buffers[i] - main_grad.copy_(buffers[0]) - main_grad /= self.world_size - - master_weight -= main_grad * self.lr - weight.data.copy_(master_weight) - - -class MiniFSDP: - def __init__(self, weights, lr, dp_group): - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - self.weights = weights - self.lr = lr - self.dp_group = dp_group - - # Flatten the weights and pad to align with world size - raw_data_list = [ - _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1) - for w in weights - ] - if isinstance(weights[0], QuantizedTensor): - raw_data_list = [_get_raw_data(w).view(-1) for w in weights] - else: - raw_data_list = [w.view(-1) for w in weights] - self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list) - - # Split flattened weights into shards - self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank] - self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard) - shard_size = self.flatten_weight.size(0) // world_size - - # Map original tensors to flattened indices - tensor_indices = [] - cumulative_length = 0 - for tensor in raw_data_list: - length = tensor.size(0) - tensor_indices.append((cumulative_length, cumulative_length + length)) - cumulative_length += length - - # Build shard index mappings - self.weight_indices = [] - self.shard_indices = [] - for idx, (start, end) in enumerate(tensor_indices): - shard_start = rank * shard_size - shard_end = shard_start + shard_size - adjusted_end = min(shard_end, original_length) - - if start 
<= adjusted_end and end >= shard_start: - start_idx = max(start, shard_start) - end_idx = min(end, adjusted_end) - self.weight_indices.append((start_idx - start, end_idx - start)) - self.shard_indices.append((start_idx - shard_start, end_idx - shard_start)) - else: - self.weight_indices.append((None, None)) - self.shard_indices.append((None, None)) - - if isinstance(weights[idx], QuantizedTensor): - replace_raw_data( - weights[idx], self.flatten_weight[start:end].view(weights[idx].shape) - ) - else: - weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape) - - # Initialize local model weights and high-precision master weights - self.local_weights = [] - self.master_weights = [] - for i, weight in enumerate(self.weights): - weight_start, weight_end = self.weight_indices[i] - shard_start, shard_end = self.shard_indices[i] - if shard_start is not None and shard_end is not None: - local_weight_shard = self.local_weight_shard[shard_start:shard_end] - self.local_weights.append(local_weight_shard) - - if isinstance(weight, QuantizedTensor): - high_precision_init_val = weight.get_high_precision_init_val().view(-1) - master_weight_shard = high_precision_init_val.to(weight.device).float()[ - weight_start:weight_end - ] - else: - master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end] - self.master_weights.append(master_weight_shard) - else: - self.local_weights.append(None) - self.master_weights.append(None) - setattr( - weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda") - ) - - def _flatten_tensors_with_pad(self, tensors): - """ - Flatten the list of tensors and pad them to align with the world size. - - Args: - tensors (list): List of tensors to flatten. - - Returns: - tuple: Flattened tensor and its original length before padding. 
- """ - world_size = dist.get_world_size(self.dp_group) - - flatten_tensor = torch.cat(tensors) - original_length = flatten_tensor.size(0) - - padding_needed = (world_size - original_length % world_size) % world_size - if padding_needed > 0: - flatten_tensor = torch.cat( - [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)] - ) - - return flatten_tensor, original_length - - def zero_grad(self): - for weight in self.weights: - weight.grad = None - weight.main_grad.zero_() - - def step(self): - """ - Perform an optimization step for the distributed sharded model. - - This method includes: - 1. Gradient reduce-scatter: Synchronize gradients across all processes. - 2. Master weight update: Update high-precision master weights using local gradients. - 3. Precision casting: Cast updated master weights to FP8 or BF16 precision. - 4. Weight synchronization: All-gather updated weights across all processes. - - Returns: - None - """ - # Step 1: Reduce-scatter the gradients - main_grad_buffer, _ = self._flatten_tensors_with_pad( - [weight.main_grad.view(-1) for weight in self.weights] - ) - main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype) - dist.reduce_scatter_tensor( - self.local_main_grad_shard, main_grad_buffer, group=self.dp_group - ) - - # Step 2: Update the master weights - for weight, master_weight, (shard_start, shard_end) in zip( - self.weights, self.master_weights, self.shard_indices - ): - if master_weight is None: - continue - - # Extract the local gradient shard for this weight - grad = self.local_main_grad_shard[shard_start:shard_end] - - # Update the master weight using gradient descent - master_weight -= grad * self.lr - - # Step 3: Cast master weights to FP8 or BF16 precision - if isinstance(self.weights[0], QuantizedTensor): - local_weights = [] - for local_weight in self.local_weights: - if local_weight is None: - local_weights.append(None) - continue - - local_weights.append(local_weight) - - 
cast_master_weights_to_fp8( - self.weights, - self.master_weights, - [idx[0] for idx in self.weight_indices], - self.dp_group, - local_weights, - ) - else: - for weight, master_weight in zip(self.local_weights, self.master_weights): - if master_weight is None: - continue - - # Copy updated master weights to local weights - weight.data.copy_(master_weight) - - # Step 4: All-gather updated weights across processes - dist.all_gather_into_tensor( - self.flatten_weight, self.local_weight_shard, group=self.dp_group - ) - - -def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group): - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - # Configuration constants - NUM_STEPS = 100 - SEED = 12345 - - torch.manual_seed(SEED) - torch.cuda.manual_seed(SEED) - - mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] - mock_group = mock_groups[rank] - - linear_kwargs = { - "params_dtype": torch.bfloat16, - "bias": False, - "fuse_wgrad_accumulation": False, - } - - # Create model with FP8 weights - with te.fp8.fp8_model_init( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - preserve_high_precision_init_val=True, - ): - model_fp8 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Create model with BF16 weights - model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Make sure the BF16 model and FP8 model have the same initial weights - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - high_precision_init_val = w_fp8.get_high_precision_init_val() - w.data.copy_(high_precision_init_val) - - optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group) - optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) - - for _ in 
range(100): - optimizer_fp8.zero_grad() - optimizer.zero_grad() - - inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) - ] - # Choose based on rank to make sure the inputs of different ranks are different. - x = inputs[rank] - - with te.fp8.fp8_autocast( - enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, - ): - y_fp8 = model_fp8(x) - - with te.fp8_autocast( - enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, - ): - y = model(x) - - targets = [torch.randn_like(y) for _ in range(world_size)] - # Choose based on rank to make sure the targets of different ranks are different. - target = targets[rank] - loss_fp8 = nn.MSELoss()(y_fp8, target) - loss = nn.MSELoss()(y, target) - - loss_fp8.backward() - loss.backward() - - optimizer_fp8.step() - optimizer.step() - - torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) - - print( - f"✅ Successfully validated FSDP {NUM_STEPS} training steps with" - f" {quantization} quantization" - ) - - -def _test_zero_1(dp_group): - """Make sure the implementation of zero-1 optimizer is correct""" - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - torch.manual_seed(12345) - torch.cuda.manual_seed(12345) - - weights = [ - torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"), - torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"), - torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"), - ] - - weights_1 = weights - weights_2 = [weight.clone() for weight in weights] - - lr = 1.0 - optimizer_1 = MiniZero_1(weights_1, lr, dp_group) - optimizer_2 = MiniOptimizer(weights_2, lr, dp_group) - - for _ in range(100): - for w1, w2 in zip(weights_1, weights_2): - main_grads = [ - torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size) - ] - # Choose based on rank to make sure the grads of different 
ranks are different. - main_grad = main_grads[rank] - w1.main_grad = main_grad - w2.main_grad = main_grad - - optimizer_1.step() - optimizer_2.step() - - for w1, w2 in zip(weights_1, weights_2): - torch.testing.assert_close(w1, w2, atol=0, rtol=0) - - -def quantization_recipe(quantization) -> Recipe: - """Quantization recipe setup""" - fp8_format = Format.HYBRID - if quantization == "fp8": - return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max") - elif quantization == "fp8_cs": - return Float8CurrentScaling(fp8_format=fp8_format) - elif quantization == "fp8_block": - return Float8BlockScaling(fp8_format=fp8_format) - else: - raise ValueError(f"Unsupported quantization: {quantization}") - - -def _test_cast_master_weights_to_fp8(quantization, dp_group): - rank = dist.get_rank(dp_group) - world_size = dist.get_world_size(dp_group) - - torch.manual_seed(12345) - torch.cuda.manual_seed(12345) - - mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] - mock_group = mock_groups[rank] - - linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} - - # Create model with FP8 weights - with te.fp8.fp8_model_init( - enabled=quantization is not None, - recipe=quantization_recipe(quantization), - preserve_high_precision_init_val=True, - ): - model_fp8 = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Create model with BF16 weights - model = nn.Sequential( - te.Linear(128, 256 + 16, **linear_kwargs), - te.Linear(256 + 16, 256 * 3, **linear_kwargs), - te.Linear(256 * 3, 128, **linear_kwargs), - ) - - # Make sure the BF16 model and FP8 model have the same initial weights - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - high_precision_init_val = w_fp8.get_high_precision_init_val() - w.data.copy_(high_precision_init_val) - - # Allocate main_grads for each weight 
- for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda") - w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") - - optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group) - optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) - - for i in range(100): - for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): - w_fp8.main_grad.zero_() - w.main_grad.zero_() - - inputs = [ - torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) - ] - # Choose based on rank to make sure the inputs of different ranks are different. - x = inputs[rank] - - with te.fp8.fp8_autocast( - enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, - ): - y_fp8 = model_fp8(x) - - with te.fp8_autocast( - enabled=quantization is not None, - fp8_recipe=quantization_recipe(quantization), - fp8_group=mock_group, - ): - y = model(x) - - targets = [torch.randn_like(y) for _ in range(world_size)] - # Choose based on rank to make sure the targets of different ranks are different. 
- target = targets[rank] - loss_fp8 = nn.MSELoss()(y_fp8, target) - loss = nn.MSELoss()(y, target) - - loss_fp8.backward() - loss.backward() - - optimizer_fp8.step() - optimizer.step() - - torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) - - -def main(argv=None, namespace=None): - WORLD_RANK = int(os.getenv("RANK", "0")) - WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) - LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) - LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) - - assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node - assert LOCAL_SIZE <= torch.cuda.device_count() - dist_init_kwargs = { - "backend": "nccl", - "rank": WORLD_RANK, - "world_size": WORLD_SIZE, - "timeout": datetime.timedelta(seconds=30), - } - dist_init_kwargs["init_method"] = "env://" - dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") - assert dist.is_nccl_available() - torch.cuda.set_device(LOCAL_RANK) - dist.init_process_group(**dist_init_kwargs) - - parser = argparse.ArgumentParser() - parser.add_argument( - "--quantization", type=str, default=None, choices=["fp8", "fp8_cs", "fp8_block"] - ) - args = parser.parse_args(argv, namespace) - - dp_group = dist.new_group(backend="nccl") - _test_zero_1(dp_group) - _test_cast_master_weights_to_fp8(args.quantization, dp_group) - _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group) - - dist.destroy_process_group() - return 0 - - -if __name__ == "__main__": - - sys.exit(main()) diff --git a/tests/pytorch/distributed/run_fsdp2_model.py b/tests/pytorch/distributed/run_fsdp2_model.py index 6f7fdab32..b9fe33593 100644 --- a/tests/pytorch/distributed/run_fsdp2_model.py +++ b/tests/pytorch/distributed/run_fsdp2_model.py @@ -1,6 +1,6 @@ #!/usr/bin/python3 # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -11,57 +11,74 @@ import argparse import transformer_engine.pytorch as te -from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.common.recipe import ( + Format, + DelayedScaling, + Float8CurrentScaling, + MXFP8BlockScaling, +) +from transformer_engine.pytorch import torch_version import torch import torch.distributed as dist +from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn, optim from torch.distributed import DeviceMesh from torch.distributed._composable.fsdp import fully_shard from torch.distributed.device_mesh import init_device_mesh +from transformer_engine.pytorch import QuantizedTensor from contextlib import nullcontext -from transformer_engine.pytorch import torch_version - -class SimpleNet(nn.Module): - def __init__(self, input_size, hidden_size, output_size): - super(SimpleNet, self).__init__() - self.fc1 = te.Linear(input_size, hidden_size) - self.fc2 = te.Linear(hidden_size, output_size) - def forward(self, x): - x = F.relu(self.fc1(x)) - x = self.fc2(x) - return x +LOCAL_RANK = None -def save_custom_attrs(module): - custom_attrs = {} - for name, param in module.named_parameters(): - attrs = vars(param) - custom_attrs[name] = {k: v for k, v in attrs.items()} - return custom_attrs - - -def restore_custom_attrs(module, custom_attrs): - for name, param in module.named_parameters(): - if name in custom_attrs: - for attr_name, attr_value in custom_attrs[name].items(): - setattr(param, attr_name, attr_value) +def dist_print(msg): + if LOCAL_RANK == 0: + print(msg) def _parse_args(argv=None, namespace=None): parser = argparse.ArgumentParser(description="Toy example for debugging fully_shard()") - parser.add_argument("--input-size", type=int, default=2048, help="Input size for the model") - parser.add_argument("--hidden-size", type=int, default=2048, 
help="Hidden layer size") - parser.add_argument("--output-size", type=int, default=2048, help="Output size for the model") - parser.add_argument("--batch-size", type=int, default=2048, help="Output size for the model") + parser.add_argument("--num-heads", type=int, default=8, help="Number of attn. heads") + parser.add_argument("--head-dim", type=int, default=64, help="Attention head size") + parser.add_argument("--batch-size", type=int, default=16, help="Batch size of input") + parser.add_argument("--seq-length", type=int, default=128, help="Sequence length of input") + parser.add_argument("--params-dtype", type=str, default="float32", help="Parameter dtype.") parser.add_argument( "--fp8-init", action="store_true", default=False, help="Initialize primary weights in FP8." ) + parser.add_argument( + "--recipe", + type=str, + default="mx_fp8_block_scaling", + help="Quantizer type.", + choices=["delayed_scaling", "current_scaling", "mx_fp8_block_scaling"], + ) + parser.add_argument( + "--layer-type", + type=str, + default="TransformerLayer", + choices=[ + "Linear", + "LayerNormLinear", + "LayerNormMLP", + "MultiheadAttention", + "TransformerLayer", + ], + help="Transformer Engine layer type", + ) + parser.add_argument("--num-layers", type=int, default=4, help="Number of layers in the model") parser.add_argument( "--iter", type=int, default=10, help="Number of iterations for forward pass" ) + parser.add_argument( + "--device", + type=str, + default="meta", + help="Device to run the model on.", + choices=["cuda", "meta"], + ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") # Adding hsdp_dim as a list argument, comma-separated parser.add_argument( @@ -76,10 +93,170 @@ def _parse_args(argv=None, namespace=None): return args -sub_modules_to_wrap = [te.Linear] +## Methods to help initialize the TE model in an FSDP2 setting +## with required configurations based on command line args +def get_te_layer_from_string(layer_name): + te_layer_types = [ + 
te.Linear, + te.LayerNormLinear, + te.LayerNormMLP, + te.MultiheadAttention, + te.TransformerLayer, + ] + te_layer_names = [layer.__name__ for layer in te_layer_types] + te_layer_map = dict(zip([name.lower() for name in te_layer_names], te_layer_types)) + if layer_name.lower() not in te_layer_map.keys(): + raise argparse.ArgumentTypeError( + f'"{layer_name}" is not a valid Transformer Engine layer, ' + f"please choose layer from {te_layer_names}." + ) + return te_layer_map[layer_name.lower()] + + +def get_recipe_from_string(recipe, fp8_format=Format.HYBRID): + if recipe == "delayed_scaling": + return DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + elif recipe == "current_scaling": + return Float8CurrentScaling(fp8_format=fp8_format) + elif recipe == "mx_fp8_block_scaling": + return MXFP8BlockScaling(fp8_format=fp8_format) + else: + raise ValueError(f"Unknown quantizer type: {recipe}") + + +def init_te_model(config): + hidden_size = config.num_heads * config.head_dim + args = [hidden_size, hidden_size] + inp_shape = [config.seq_length, config.batch_size, hidden_size] + out_shape = [config.seq_length, config.batch_size, hidden_size] + if config.params_dtype == "float16": + params_dtype = torch.float16 + elif config.params_dtype == "bfloat16": + params_dtype = torch.bfloat16 + else: + params_dtype = torch.float32 + kwargs = { + "params_dtype": params_dtype, + } + kwargs["device"] = config.device + + layer_type = get_te_layer_from_string(config.layer_type) + # We are creating model in a way so that we can test both reshard_after_forward=True/False cases. + # more details below. + if layer_type in [te.MultiheadAttention, te.TransformerLayer]: + # For this case, we are creating a model that resemebles production use-cases + # wherein there are mltiple TransformerLayers in the model. And we would need + # to shard each transformer layer. 
Since each transformer layer is not a root module, + # FSDP2's fully_shard assigns reshard_after_forward=False for all parameters of the model. + args[1] *= 4 # FFN hidden size + args.append(config.num_heads) + kwargs["fuse_qkv_params"] = True + if layer_type is te.MultiheadAttention: + kwargs["input_layernorm"] = True + model = nn.Sequential(*[layer_type(*args, **kwargs) for _ in range(config.num_layers)]) + elif layer_type == te.LayerNormLinear: + # For this case, we are creating a model with just one LayerNormLinear layer + # so that the model itself is a root module, and FSDP2's fully_shard assigns + # reshard_after_forward=True for the parameters of these model. + args[1] *= 3 # QKV projection + out_shape[-1] *= 3 + model = layer_type(*args, **kwargs) + else: + model = layer_type(*args, **kwargs) + + return model, inp_shape, out_shape + + +def get_device_mesh(world_size, sharding_dims): + dist_print(f"sharding-dims:{sharding_dims}") + device_ids = list(range(world_size)) + if sharding_dims is None: # FSDP + mesh = DeviceMesh("cuda", device_ids) + elif len(sharding_dims) == 1: + assert sharding_dims[0] == world_size + mesh = DeviceMesh("cuda", device_ids) + elif len(sharding_dims) == 2: # HSDP + assert sharding_dims[0] * sharding_dims[1] == world_size + mesh = init_device_mesh( + "cuda", + (sharding_dims[0], sharding_dims[1]), + mesh_dim_names=("replicate", "shard"), + ) + else: + assert False + return mesh + + +def shard_model_with_fsdp2(model, mesh): + for child in model.children(): + fully_shard(child, mesh=mesh) + fully_shard(model, mesh=mesh) + return model + + +#### Methods to save the custom attributes of QuantizedTensors before sharding +#### them with FSDP2, and restore them after sharding. +def save_custom_attrs(module): + custom_attrs = {} + for name, param in module.named_parameters(): + if isinstance(param, QuantizedTensor): + # Ignore FP8 metadata attributes. 
Otherwise we will save duplicate copies + # for data/transpose FP8 tensors on top of FP8 tensors that FSDP2 will save. + ignore_keys = [key for key in param.__dict__.keys() if key.startswith("_")] + else: + ignore_keys = [] + attrs = vars(param) + custom_attrs[name] = {k: v for k, v in attrs.items() if k not in ignore_keys} + return custom_attrs + + +def restore_custom_attrs(module, custom_attrs): + for name, param in module.named_parameters(): + if name in custom_attrs: + for attr_name, attr_value in custom_attrs[name].items(): + setattr(param, attr_name, attr_value) + + +@torch.no_grad() +def test_fp8_fsdp2_allgather(model): + # Do manual allgather in fp32 and match against fp8 allgather done + # with fsdp2 + # FP32 manual weight allgather + fp32_allgathered_params = {} + for name, param in model.named_parameters(): + assert isinstance(param, DTensor) + local_tensor = param._local_tensor + device_mesh = param.device_mesh + dist_group = ( + device_mesh.get_group(mesh_dim="shard") + if device_mesh.ndim > 1 + else device_mesh.get_group() + ) + # Perform manual allgather on local_tensor. zeros_like will create hp tensor since torch_dispatch + # for local_tensor will go down the dequantization route. + gathered_tensor = [ + torch.zeros_like(local_tensor) for _ in range(dist.get_world_size(group=dist_group)) + ] + dist.all_gather(gathered_tensor, local_tensor.dequantize(), group=dist_group) + full_tensor = torch.cat(gathered_tensor, dim=0) + fp32_allgathered_params[name] = full_tensor + # FP8 allgather using FSDP2 + for module in model.modules(): + # Not all modules are wrapped/sharded with FSDP2. + if hasattr(module, "unshard"): + module.unshard() + # Make sure allgathered parameters match exactly + for name, param in model.named_parameters(): + assert torch.allclose(param.dequantize(), fp32_allgathered_params[name]) + # Revert model to original sharded state + for module in model.modules(): + # Not all modules are wrapped/sharded with FSDP2. 
+ if hasattr(module, "reshard"): + module.reshard() def _train(args): + global LOCAL_RANK assert "TORCHELASTIC_RUN_ID" in os.environ WORLD_RANK = int(os.getenv("RANK", "0")) WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) @@ -105,73 +282,67 @@ def _train(args): # FP8 Configuration fp8_format = Format.HYBRID - fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max") + fp8_recipe = get_recipe_from_string(args.recipe, fp8_format) + build_model_context_args = {} if not args.fp8_init: # Build model context (FP8 init) build_model_context = nullcontext - build_model_context_args = {} - + else: from transformer_engine.pytorch import fp8_model_init build_model_context = fp8_model_init build_model_context_args["enabled"] = True + build_model_context_args["recipe"] = fp8_recipe - # Build the model with the specified context - with build_model_context(**build_model_context_args): - model = SimpleNet(args.input_size, args.hidden_size, args.output_size) - else: - model = SimpleNet(args.input_size, args.hidden_size, args.output_size) - # Move the model to the correct device - - model.to(device) + dist_print(f"Memory before model init: {torch.cuda.memory_allocated(device)/1e6} MB") + # Create the model on the meta/cuda device as per args + with build_model_context(**build_model_context_args): + model, inp_shape, out_shape = init_te_model(args) + dist_print( + f"Memory after model init on device {args.device}:" + f" {torch.cuda.memory_allocated(device)/1e6} MB" + ) - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Applying FSDP fully_shard() to the model...") # Creating a DeviceMesh for fully_shard world_size = int(WORLD_SIZE) - device_ids = list(range(world_size)) - if LOCAL_RANK == 0: - print(f"sharding-dims:{args.sharding_dims}") # Setup the sharding mesh for FSDP/HSDP - if args.sharding_dims == None: # FSDP - mesh = DeviceMesh("cuda", device_ids) - elif len(args.sharding_dims) == 1: - assert args.sharding_dims[0] == device_ids[-1] + 1 - 
mesh = DeviceMesh("cuda", device_ids) - elif len(args.sharding_dims) == 2: # HSDP - assert args.sharding_dims[0] * args.sharding_dims[1] == device_ids[-1] + 1 - mesh = init_device_mesh( - "cuda", - (args.sharding_dims[0], args.sharding_dims[1]), - mesh_dim_names=("replicate", "shard"), - ) - else: - assert False - - # Apply FSDP/HSDP + mesh = get_device_mesh(world_size, args.sharding_dims) custom_attrs = save_custom_attrs(model) - for sub_module in model.modules(): - if any( - isinstance(sub_module, sub_module_to_wrap) for sub_module_to_wrap in sub_modules_to_wrap - ): - fully_shard(sub_module, mesh=mesh) - fully_shard(model, mesh=mesh) + model = shard_model_with_fsdp2(model, mesh) restore_custom_attrs(model, custom_attrs) + # model now has DTensors as its parameters + + if args.device == "meta": + # After FSDP2 has been applied, materialize and initialize the sharded parameters + # TE base.py's reset_parameters() handles DTensors with FP8 initialization + for module in model.modules(): + if hasattr(module, "reset_parameters"): + module.reset_parameters() + dist_print(f" Sharded parameters materialized and initialized on cuda device.") + + dist_print( + f"FSDP2 model in cuda, memory allocated: {torch.cuda.memory_allocated(device)/1e6} MB" + ) optimizer = optim.Adam(model.parameters(), lr=1e-3) for iteration in range(args.iter): # Zero the parameter gradients optimizer.zero_grad() - input_data = torch.randn(args.batch_size, args.input_size).to(device) - output = model(input_data) - target = torch.randn(args.batch_size, args.output_size).to(device) + input_data = torch.randn(inp_shape).to(device) + with te.autocast(enabled=True, recipe=fp8_recipe): + output = model(input_data) + target = torch.randn(out_shape).to(device) loss = F.mse_loss(output, target) loss.backward() optimizer.step() - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Iteration {iteration} completed.") + dist_print(f"Iteration {iteration} completed with loss {loss.item()}") + + # Some of the FSDP 
states are lazy initialized during FSDP forward pass + # so testing fp8 allgather at the end of the training loop. + if args.fp8_init: + test_fp8_fsdp2_allgather(model) # NOTE: In PyTorch < 2.6 there’s a teardown race where one rank may call # destroy_process_group() while other ranks still have in-flight NCCL ops, @@ -180,8 +351,6 @@ def _train(args): if torch_version() < (2, 6, 0): dist.barrier(device_ids=[torch.cuda.current_device()]) dist.destroy_process_group() - if LOCAL_RANK == 0: - print(f"Rank {LOCAL_RANK}: Done...") return 0 diff --git a/tests/pytorch/distributed/run_gemm_with_overlap.py b/tests/pytorch/distributed/run_gemm_with_overlap.py index 6d9e2f152..df0e4a216 100644 --- a/tests/pytorch/distributed/run_gemm_with_overlap.py +++ b/tests/pytorch/distributed/run_gemm_with_overlap.py @@ -18,9 +18,12 @@ from torch.distributed.elastic.multiprocessing.errors import record import transformer_engine.pytorch as te +from transformer_engine.pytorch import ( + Float8Tensor, + Float8Quantizer, + MXFP8Quantizer, +) import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.module.base import ( fill_userbuffers_buffer_for_all_gather, get_cublas_workspace_size_bytes, @@ -171,12 +174,12 @@ def _parse_args(argv=None, namespace=None): opts.p2p = True if opts.atomic: - if not te.fp8.check_fp8_support(): + if not te.is_fp8_available(): assert opts.quantization == "none", "Atomic GEMM is only supported in FP8." opts.quantization = "fp8" if opts.fp8_output: - assert ops.quantization == "fp8", "FP8 output is only supported with FP8 compute." + assert opts.quantization == "fp8", "FP8 output is only supported with FP8 compute." 
return opts diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 2a6e55b2c..b2bd6dd77 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -165,7 +165,7 @@ def _parse_args(argv=None, namespace=None): ) parser.add_argument("--seed", type=int, default=42, help="RNG seed.") parser.add_argument( - "--fp8", action="store_true", default=False, help="Enables the te.fp8_autocast() context." + "--fp8", action="store_true", default=False, help="Enables the te.autocast() context." ) parser.add_argument( "--quantization", @@ -438,7 +438,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, ) - with te.fp8_model_init(enabled=opts.fp8_init): + with te.quantized_model_init(enabled=opts.fp8_init): test_model = multi_module_model(opts.layer_type, opts.num_layers, *args, **kwargs) dist_print("Initialized test model...", debug=True) if WORLD_RANK == 0: @@ -450,7 +450,7 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): ref_args, ref_kwargs, _ = _get_layer_args( opts, nccl_world, opts.tp, num_layers=opts.num_layers, reference=True ) - with te.fp8_model_init(enabled=opts.fp8_init): + with te.quantized_model_init(enabled=opts.fp8_init): ref_model = multi_module_model(opts.layer_type, opts.num_layers, *ref_args, **ref_kwargs) dist_print("Initialized reference model...", debug=True) for test_param, ref_param in zip(test_model.parameters(), ref_model.parameters()): @@ -473,7 +473,9 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): layer_contexts = [ ( - partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world) + partial( + te.autocast, enabled=opts.fp8, recipe=fp8_recipe, amax_reduction_group=nccl_world + ) if opts.num_layers_at_start_in_bf16 <= i and i < (opts.num_layers - opts.num_layers_at_end_in_bf16) 
else nullcontext diff --git a/tests/pytorch/distributed/run_numerics.py b/tests/pytorch/distributed/run_numerics.py index 8a201b72d..183cef35c 100644 --- a/tests/pytorch/distributed/run_numerics.py +++ b/tests/pytorch/distributed/run_numerics.py @@ -11,6 +11,7 @@ import os import sys from functools import wraps +import math import transformer_engine.pytorch as te import torch @@ -23,10 +24,14 @@ DelayedScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, Format, Recipe, + QParams, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch import Float8CurrentScalingQuantizer, NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.distributed import gather_along_first_dim from run_layer_with_overlap import _compare_tensors if IS_HIP_EXTENSION: @@ -53,6 +58,14 @@ ) +def nvfp4_vanilla(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams() + nvfp4_recipe.fp4_quant_fwd_weight = QParams() + nvfp4_recipe.fp4_quant_bwd_grad = QParams() + return nvfp4_recipe + + # Quantization recipe setup def quantization_recipe() -> Recipe: if QUANTIZATION == "fp8": @@ -65,7 +78,9 @@ def quantization_recipe() -> Recipe: return Float8CurrentScaling() if QUANTIZATION == "fp8_block_scaling": return Float8BlockScaling() - return te.fp8.get_default_fp8_recipe() + if QUANTIZATION == "nvfp4": + return nvfp4_vanilla() + return te.quantization.get_default_fp8_recipe() if IS_HIP_EXTENSION: @@ -113,10 +128,14 @@ def main(argv=None, namespace=None): # Quantization scheme QUANTIZATION = args.quantization global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE - if QUANTIZATION in ("fp8", "mxfp8"): + if QUANTIZATION in ("fp8", "mxfp8", "nvfp4"): SEQ_LEN = 32 BATCH_SIZE = 32 HIDDEN_SIZE = 128 + # For fp8 block scaling, block size is 128, + # and to make low precision TP work, input tensor + # must be 128x128 divisible to be eligible for + 
# low precision All-Gather when needed elif QUANTIZATION == "fp8_block_scaling": SEQ_LEN = 128 BATCH_SIZE = 128 @@ -124,6 +143,7 @@ def main(argv=None, namespace=None): test_dict = [ test_quantizer, + test_quantized_all_gather, test_linear, test_layernorm, test_layernorm_linear, @@ -193,6 +213,9 @@ def _get_tolerances(dtype): # row parallel & sequence parallel, because we do the all_gather in backward pass if QUANTIZATION == "fp8_cs": return {"rtol": 0.4, "atol": 0.25} + elif QUANTIZATION == "nvfp4": + # TODO(zhongboz): investigate why the tolerance is so large + return {"rtol": 0.125, "atol": 0.12} elif QUANTIZATION is not None: return {"rtol": 0.125, "atol": 0.0625} @@ -314,15 +337,15 @@ def _apply_models( _alloc_main_grad(model_single_node, model_distributed) # for fuse_wgrad_accumulation=True input_single_node.requires_grad_() input_distributed.requires_grad_() - with te.fp8_autocast( + with te.autocast( enabled=QUANTIZATION is not None, - fp8_recipe=quantization_recipe(), + recipe=quantization_recipe(), ): output_single_node = model_single_node(input_single_node, **kwargs) - with te.fp8_autocast( + with te.autocast( enabled=QUANTIZATION is not None, - fp8_recipe=quantization_recipe(), - fp8_group=NCCL_WORLD, + recipe=quantization_recipe(), + amax_reduction_group=NCCL_WORLD, ): output_distributed = model_distributed(input_distributed, **kwargs) return output_single_node, output_distributed @@ -348,24 +371,36 @@ def _alloc_main_grad(model_single_node, model_distributed): ############################################### # Quantizer # ############################################### -def _construct_quantizer(quantizer_class, fp8_dtype, device, tp_group, tp_size): +def _construct_quantizer(quantizer_class, low_precision_dtype, device, tp_group, tp_size): """ quantizer is the reference quantizer on a single GPU. quantizer_dist is the distributed quantizer to be tested on multiple GPUs. 
""" if quantizer_class == Float8CurrentScalingQuantizer: quantizer_dist = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=True, amax_reduction_group=tp_group, ) quantizer = quantizer_class( - fp8_dtype=fp8_dtype, + fp8_dtype=low_precision_dtype, device=device, with_amax_reduction=False, ) return quantizer, quantizer_dist + elif quantizer_class == NVFP4Quantizer: + quantizer_dist = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=True, + amax_reduction_group=tp_group, + ) + quantizer = quantizer_class( + fp4_dtype=low_precision_dtype, + with_amax_reduction=False, + amax_reduction_group=None, + ) + return quantizer, quantizer_dist else: raise ValueError(f"Unsupported quantizer class: {quantizer_class}") @@ -436,6 +471,194 @@ def test_quantizer(): _test_quantizer(input_dtype, fp8_dtype) +############################################ +# Quantized All-Gather # +############################################ + + +def _ref_zero_padding_scale_inv(scale_inv, unpadded_shape): + """ + Zero padding the scale_inv. 
+ scale_inv shape is the padded shape, but not zero padded + unpadded_shape is the original shape before padding + """ + dim0, dim1 = scale_inv.shape + unpadded_dim0, unpadded_dim1 = unpadded_shape + pad_dim0 = (128 - unpadded_dim0 % 128) % 128 + pad_dim1 = (4 - unpadded_dim1 % 4) % 4 + new_dim0 = unpadded_dim0 + pad_dim0 + new_dim1 = unpadded_dim1 + pad_dim1 + + assert dim0 == new_dim0 + assert dim1 == new_dim1 + + # return input if no padding is needed + if pad_dim0 == 0 and pad_dim1 == 0: + return scale_inv + + # unpad first to remove random bits from torch empty + scale_inv = scale_inv[:unpadded_dim0, :unpadded_dim1].contiguous() + # using torch padding + new_scale_inv = torch.nn.functional.pad( + scale_inv, (0, pad_dim1, 0, pad_dim0), mode="constant", value=0 + ) + + assert new_scale_inv.shape == (new_dim0, new_dim1) + + return new_scale_inv + + +def _get_unpadded_scale_inv_shape(input_shape, quantizer_cls, columnwise): + """ + Calculate the unpadded shape of the scale_inv tensor. + """ + M, K = 1, 1 + M = math.prod(input_shape[:-1]) + K = input_shape[-1] + + if quantizer_cls == NVFP4Quantizer: + if columnwise: + outer = K + inner = math.ceil(M / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + outer = M + inner = math.ceil(K / NVFP4_BLOCK_SCALING_SIZE) + return (outer, inner) + else: + raise ValueError(f"Unsupported quantizer class: {quantizer_cls}") + + +@run_distributed_test() +def _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls): + """Test the quantizer under distributed settings. + + Args: + input_dtype (torch.dtype): The data type of the input. + low_precision_dtype (tex.DType): The data type of the low precision, can be fp4 or fp8. 
+ """ + + M, N = WORLD_SIZE * BATCH_SIZE, HIDDEN_SIZE // 2 + + # high precision input + x_hp_cpu = torch.randn((M, N), device="cpu").to(input_dtype) + # set one element of the input to a very large value, which doesn't live in rank 0 after the split + # to test the amax reduction on purpose + # x_hp_cpu[M - 1, N - 1] = 1e4 + + # get the unpadded shapes + unpadded_rowwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, False) + unpadded_columnwise_scale_inv_shape = _get_unpadded_scale_inv_shape((M, N), quantizer_cls, True) + + # rank 0 takes the full copy and quantize with GPU 0 for verification + if WORLD_RANK == 0: + x_hp_rank0 = x_hp_cpu.clone().detach().requires_grad_(True).to("cuda") + x_hp_local_rank = _shard_tensor(x_hp_cpu, WORLD_SIZE, 0)[WORLD_RANK] + + # Create quantizers + quantizer, quantizer_dist = _construct_quantizer( + quantizer_cls, low_precision_dtype, x_hp_local_rank.device, NCCL_WORLD, WORLD_SIZE + ) + + # quantize the entire input + if WORLD_RANK == 0: + x_low_precision_single = quantizer(x_hp_rank0) + + # run all-gather with a quantizer as input for quantized all-gather + x_low_precision_total, _ = gather_along_first_dim( + x_hp_local_rank, NCCL_WORLD, async_op=False, quantizer=quantizer_dist + ) + + # check the outputs + if WORLD_RANK == 0: + # assert all data and scale_inv are the same + torch.testing.assert_close( + x_low_precision_single._rowwise_data, + x_low_precision_total._rowwise_data, + rtol=0.0, + atol=0.0, + ) + # check the rowwise scale without any padding + unpad_dim0, unpad_dim1 = unpadded_rowwise_scale_inv_shape + unpadded_rowwise_scale_inv_ref = x_low_precision_single._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_rowwise_scale_inv = x_low_precision_total._rowwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_rowwise_scale_inv_ref, + unpadded_rowwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + 
x_low_precision_single._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._rowwise_scale_inv, unpadded_rowwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + x_low_precision_single._columnwise_data, + x_low_precision_total._columnwise_data, + rtol=0.0, + atol=0.0, + ) + unpad_dim0, unpad_dim1 = unpadded_columnwise_scale_inv_shape + unpadded_columnwise_scale_inv_ref = x_low_precision_single._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + unpadded_columnwise_scale_inv = x_low_precision_total._columnwise_scale_inv[ + :unpad_dim0, :unpad_dim1 + ] + torch.testing.assert_close( + unpadded_columnwise_scale_inv_ref, + unpadded_columnwise_scale_inv, + rtol=0.0, + atol=0.0, + ) + torch.testing.assert_close( + _ref_zero_padding_scale_inv( + x_low_precision_single._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + _ref_zero_padding_scale_inv( + x_low_precision_total._columnwise_scale_inv, unpadded_columnwise_scale_inv_shape + ), + rtol=0.0, + atol=0.0, + ) + + +def test_quantized_all_gather(): + """ + Run quantized all-gather tests with various configurations. 
+ """ + # skip this test for other quantization schemes + is_nvfp4 = QUANTIZATION == "nvfp4" + # add other recipes for testing if needed + if not is_nvfp4: + return + + input_dtypes = [torch.bfloat16] + fp4_dtype = [tex.DType.kFloat4E2M1] + fp8_dtype = [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2] + quantizer_cls_nvfp4 = [NVFP4Quantizer] + # add FP8 quantizers if needed + quantizer_cls_fp8 = [] + + low_precisio_dtypes = fp4_dtype if is_nvfp4 else fp8_dtype + quantizer_cls_list = quantizer_cls_nvfp4 if is_nvfp4 else quantizer_cls_fp8 + + for quantizer_cls in quantizer_cls_list: + for input_dtype in input_dtypes: + for low_precision_dtype in low_precisio_dtypes: + _test_quantized_all_gather(input_dtype, low_precision_dtype, quantizer_cls) + + ############################################ # Linear # ############################################ @@ -536,10 +759,11 @@ def test_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"delay_wgrad_compute": True}, {"save_original_input": True}, ] + for kwargs in kwargs_list: if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8": continue @@ -715,11 +939,12 @@ def test_layernorm_linear(): {"init_method": _constant}, {"fuse_wgrad_accumulation": True}, {"return_bias": True}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"zero_centered_gamma": False}, {"return_layernorm_output": True}, {"delay_wgrad_compute": True}, ] + for kwargs in kwargs_list: for parallel_mode in ["column"]: for sequence_parallel in [False, True]: @@ -821,7 +1046,7 @@ def test_layernorm_mlp(): {"normalization": "RMSNorm"}, {"zero_centered_gamma": True}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"activation": "relu"}, 
{"fuse_wgrad_accumulation": True}, {"return_bias": True}, @@ -919,7 +1144,7 @@ def test_transformer_layer(): {"fuse_qkv_params": True, "fuse_wgrad_accumulation": True}, {"qkv_weight_interleaved": False}, {"bias": False}, - {"params_dtype": torch.float16}, + {"params_dtype": torch.float16 if QUANTIZATION != "nvfp4" else torch.bfloat16}, {"fuse_qkv_params": True}, {"activation": "relu"}, ] diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py new file mode 100644 index 000000000..3605b3c70 --- /dev/null +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -0,0 +1,756 @@ +#!/usr/bin/python3 + +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import argparse +import datetime +import os +import sys +from functools import wraps + +import transformer_engine.pytorch as te +import torch +from torch import nn +import torch.distributed as dist +from transformer_engine.common.recipe import ( + NVFP4BlockScaling, + Recipe, + QParams, + CustomRecipe, +) +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE +from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import utils +from run_layer_with_overlap import _compare_tensors + + +BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE = 128, 256, 128 +WORLD_RANK, WORLD_SIZE = None, None +NCCL_WORLD = None +LOSS_FN = nn.MSELoss() +QUANTIZATION = None + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + 
+def get_nvfp4_quantizer_factory(): + """ + Create a quantizer factory for NVFP4 reference implementation. + + This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization + enabled. + + Returns: + A factory function that takes a role string and returns a quantizer instance + """ + + def factory(role): + if role == "linear_input": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, # RHT enabled for input + ) + elif role == "linear_weight": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), # 2D quantization for weight + pow_2_scales=False, + with_rht=False, + ) + elif role == "linear_output": + # Output quantization not used + return None + elif role == "linear_grad_output": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, # RHT enabled for grad_output + ) + elif role == "linear_grad_input": + # Grad input quantization not used + return None + else: + # For any other roles, return None + return None + + return factory + + +# Quantization recipe setup +def quantization_recipe() -> Recipe: + if QUANTIZATION == "nvfp4": + return nvfp4_rht_and_2d_quantization() + raise ValueError(f"Unsupported quantization: {QUANTIZATION}") + + +def quantization_reference_recipe() -> Recipe: + """Create reference recipe using CustomRecipe with NVFP4 quantizer factory.""" + if QUANTIZATION == "nvfp4": + nvfp4_ref_factory = get_nvfp4_quantizer_factory() + return CustomRecipe(qfactory=nvfp4_ref_factory) + raise ValueError(f"Unsupported quantization for reference: {QUANTIZATION}") + + +def main(argv=None, namespace=None): + global WORLD_RANK, WORLD_SIZE, NCCL_WORLD, QUANTIZATION, BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = 
int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + + NCCL_WORLD = dist.new_group(backend="nccl") + + WORLD_SIZE = dist.get_world_size() + + parser = argparse.ArgumentParser() + parser.add_argument("--quantization", type=str, default=None) + parser.add_argument("--batch-size", type=int, default=32) + parser.add_argument("--hidden-size", type=int, default=128) + parser.add_argument("--out-size", type=int, default=128) + args = parser.parse_args(argv, namespace) + + # Quantization scheme + QUANTIZATION = args.quantization + BATCH_SIZE = args.batch_size + HIDDEN_SIZE = args.hidden_size + OUT_SIZE = args.out_size + + test_dict = [ + test_linear, + test_layernorm_linear, + ] + + for test in test_dict: + test() + dist.destroy_process_group() + return 0 + + +def run_distributed_test(test_name=None): + def decorator(func): + @wraps(func) + def wrapper(*args, **kwargs): + name = test_name if test_name is not None else func.__name__ + + dist_print(f"Starting test {name} with args {args} and {kwargs}") + torch.cuda.set_device(WORLD_RANK) + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + func(*args, **kwargs) + + dist.barrier() + dist_print(f"Passed test {name}") + + return wrapper + + return decorator + + +def dist_print(msg, src=None, end="\n", error=False): + stream = sys.stderr if error else sys.stdout + if WORLD_RANK == (0 if src is None else src): + stream.write(f"[rank{WORLD_RANK}] {msg}{end}\n") + + 
############################################
#                  Linear                  #
############################################
class TestDistributedLinearBase:
    """Shared helpers for testing te.Linear under tensor/sequence parallelism.

    Provides data preparation, shard/gather/all-reduce utilities, and the
    pre/post-processing needed to compare a model-parallel run against a
    single-GPU reference run.
    """

    @staticmethod
    def _prepare_data(
        batch_size, hidden_size, out_size, use_bias=True, seed=0, dtype=torch.float32
    ):
        """Create deterministic random input, weight, bias and incoming gradient.

        Shapes: x [M, K], w [N, K], bias [N], gradient [M, N].
        """
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        x = torch.randn((batch_size, hidden_size), dtype=dtype, device="cuda")
        w = torch.randn((out_size, hidden_size), dtype=dtype, device="cuda")
        bias = torch.randn((out_size), dtype=dtype, device="cuda") if use_bias else None
        gradient = torch.randn((batch_size, out_size), dtype=dtype, device="cuda")

        return x, w, bias, gradient

    @staticmethod
    def _shard_tensor(x, world_size, axis):
        """Split `x` into `world_size` equal chunks along `axis`.

        Each chunk is detached and cloned so shards own their storage; the
        requires_grad flag of `x` is preserved on every shard.
        """
        split_size = x.size()[axis] // world_size
        split_tensor = torch.split(x, split_size, axis)
        out = []
        for tensor in split_tensor:
            out.append(tensor.detach().clone().requires_grad_(x.requires_grad))
        return out

    @staticmethod
    def _gather_tensor(local, world_size, tp_group, concat_dim):
        """All-gather `local` across `tp_group` and concatenate along `concat_dim`."""
        out_list = [torch.zeros_like(local) for _ in range(world_size)]
        torch.distributed.all_gather(out_list, local, tp_group)
        return torch.cat(out_list, dim=concat_dim)

    @staticmethod
    def _all_reduce_tensor(local, world_size, tp_group):
        """Sum-reduce `local` in place across `tp_group`; no-op when world_size == 1."""
        if world_size == 1:
            return local
        # Synchronous all-reduce; the reduced result lands in `local` itself.
        torch.distributed.all_reduce(local, group=tp_group, async_op=False)
        return local

    @staticmethod
    def _get_sum_abs_error(a, b):
        """Return the summed absolute elementwise error between `a` and `b`."""
        return torch.sum(torch.abs(a - b))

    @staticmethod
    def _get_mean_abs_relative_error(a, b):
        """Return mean |a-b|/|b|; where b == 0, count 1 for a != b and 0 otherwise."""
        error = torch.where(b == 0, torch.ne(a, b), torch.abs((a - b) / b))
        return torch.mean(error)

    @classmethod
    def run_linear_preprocess_parallel(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_size=1,
        rank=0,
    ):
        """Shard inputs for this rank according to the parallel mode.

        Column parallel splits w/bias/gradient along N (and x along M with SP);
        row parallel splits x/w along K (and gradient along M with SP).
        Returns inputs unchanged when tp_size == 1.
        """
        if tp_size > 1:
            if parallel_mode == "column":
                # split w in N dim, which should be axis 0
                w = cls._shard_tensor(w, tp_size, 0)[rank]
                bias = cls._shard_tensor(bias, tp_size, 0)[rank] if bias is not None else None
                # split gradient in N dim, which should be axis 1
                gradient = cls._shard_tensor(gradient, tp_size, 1)[rank]
                if sequence_parallel:
                    # split x in M dim, which should be axis 0
                    x = cls._shard_tensor(x, tp_size, 0)[rank]
            # row parallel, split x in k dim, which should be axis 1, split w in k dim, should be axis 1
            if parallel_mode == "row":
                # split x in K dim, which should be axis 1
                x = cls._shard_tensor(x, tp_size, 1)[rank]
                # split w in K dim, which should be axis 1
                w = cls._shard_tensor(w, tp_size, 1)[rank]
                if sequence_parallel:
                    # split gradient in M dim, which should be axis 0
                    gradient = cls._shard_tensor(gradient, tp_size, 0)[rank]
        return x, w, bias, gradient

    @classmethod
    def run_linear_postprocess_parallel(
        cls,
        y_q,
        dgrad,
        wgrad,
        bgrad,
        parallel_mode,
        sequence_parallel,
        tp_size,
        tp_group,
    ):
        """Gather per-rank outputs/gradients back to full tensors for comparison.

        Mirrors `run_linear_preprocess_parallel`: gathers along the axes that
        were sharded. No-op when tp_size == 1.
        """
        if tp_size > 1:
            if parallel_mode == "column":
                # gather y_q in N dim, which should be axis 1
                y_q = cls._gather_tensor(y_q, tp_size, tp_group, 1)
                # gather wgrad in N dim, which should be axis 0
                wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 0)
                # gather bgrad in N dim, which should be axis 0
                bgrad = (
                    cls._gather_tensor(bgrad, tp_size, tp_group, 0) if bgrad is not None else None
                )
                if sequence_parallel:
                    # gather dgrad in M dim, which should be axis 0
                    dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 0)
            if parallel_mode == "row":
                # gather dgrad in K dim, which should be axis 1
                dgrad = cls._gather_tensor(dgrad, tp_size, tp_group, 1)
                # gather wgrad in K dim, which should be axis 1
                wgrad = cls._gather_tensor(wgrad, tp_size, tp_group, 1)
                if sequence_parallel:
                    # gather y_q in M dim, which should be axis 0
                    y_q = cls._gather_tensor(y_q, tp_size, tp_group, 0)
                    # we need to sum bias gradient when using TP + SP
                    bgrad = (
                        cls._all_reduce_tensor(bgrad, tp_size, tp_group)
                        if bgrad is not None
                        else None
                    )

        return y_q, dgrad, wgrad, bgrad
    @classmethod
    def run_linear_one_step(
        cls, layer, x, gradient, is_first_microbatch=None, fuse_wgrad_accumulation=False
    ):
        """Run one forward/backward step and return (y_q, dgrad, wgrad, bgrad).

        When fuse_wgrad_accumulation is set, the weight gradient is read from
        weight.main_grad (accumulated) instead of weight.grad.
        """
        # reset gradients
        layer.zero_grad()
        x.grad = None

        # Forward pass
        if isinstance(layer, te.Linear):
            # TE Linear: forward takes the is_first_microbatch weight-cache hint
            y_q = layer.forward(x, is_first_microbatch=is_first_microbatch)
        else:
            # the default torch.nn.Linear
            y_q = layer(x)

        # Backward pass
        y_q.backward(gradient)

        # Collect gradients
        dgrad = x.grad
        bgrad = (
            layer._parameters["bias"].grad
            if layer._parameters.get("bias", None) is not None
            else None
        )
        assert "weight" in layer._parameters
        if fuse_wgrad_accumulation:
            # fused accumulation writes into main_grad and must leave .grad unset
            wgrad = layer._parameters["weight"].main_grad
            assert layer._parameters["weight"].grad is None
        else:
            wgrad = layer._parameters["weight"].grad

        return y_q, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_multiple_steps(
        cls,
        layer,
        x,
        gradient,
        run_num_steps,
        enable_weight_cache,
        fuse_wgrad_accumulation=False,
    ):
        """
        Run multiple steps of linear layer and collect results.

        Each step perturbs the input by +i to vary activations while keeping
        weights fixed; with enable_weight_cache, only step 0 marks
        is_first_microbatch=True so later steps reuse the cached quantized weight.
        Returns stacked (y_q, dgrad, wgrad, bgrad) along a new leading step dim.
        """

        y_q_list, dgrad_list, wgrad_list = [], [], []
        # bgrad_list stays None when the layer has no bias parameter
        bgrad_list = [] if layer._parameters.get("bias", None) is not None else None

        for i in range(run_num_steps):
            x_i = (x + i).clone().detach().requires_grad_(True)
            # run_linear_one_step
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(
                layer,
                x_i,
                gradient,
                is_first_microbatch=(i == 0) if enable_weight_cache else None,
                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            )

            # Collect results (detached copies so later steps cannot alias them)
            y_q_list.append(y_q.detach().clone())
            dgrad_list.append(dgrad.detach().clone())
            wgrad_list.append(wgrad.detach().clone())
            if bgrad_list is not None and bgrad is not None:
                bgrad_list.append(bgrad.detach().clone())

        # Stack the results
        return (
            torch.stack(y_q_list),
            torch.stack(dgrad_list),
            torch.stack(wgrad_list),
            torch.stack(bgrad_list) if bgrad_list is not None else None,
        )

    @classmethod
    def run_linear(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
        run_num_steps=1,
        enable_weight_cache=False,
        fuse_wgrad_accumulation=False,
    ):
        """
        If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
        the reference single GPU run.
        """
        # clone inputs and move to current device
        # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
        x = x.clone().detach().requires_grad_(True).to("cuda")
        w = w.clone().detach().to("cuda")
        gradient = gradient.clone().detach().to("cuda")
        bias = bias.clone().detach().to("cuda") if bias is not None else None
        in_features = x.shape[1]
        out_features = w.shape[0]

        # If Model parallel: split inputs for a given rank
        x, w, bias, gradient = cls.run_linear_preprocess_parallel(
            x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
        )

        # set data types
        params_dtype = x.dtype

        # Create linear layer and copy weights
        layer = te.Linear(
            in_features,
            out_features,
            bias=bias is not None,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            tp_size=tp_size,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
        )

        layer = layer.to("cuda")

        with torch.no_grad():
            layer.weight.copy_(w)
            if bias is not None:
                layer.bias.copy_(bias)

        if fuse_wgrad_accumulation:
            # main_grad must pre-exist for fused accumulation to write into it
            assert (
                run_num_steps > 1
            ), "Fused weight gradient accumulation requires run_num_steps > 1"
            layer.weight.main_grad = torch.zeros_like(layer.weight)

        # Run one step or multiple steps
        if run_num_steps == 1:
            y_q, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)
        else:
            y_q, dgrad, wgrad, bgrad = cls.run_linear_multiple_steps(
                layer,
                x,
                gradient,
                run_num_steps,
                enable_weight_cache,
                fuse_wgrad_accumulation,
            )

        # If Model parallel: gather output and gradients from all ranks
        y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
            y_q,
            dgrad,
            wgrad,
            bgrad,
            parallel_mode,
            sequence_parallel,
            tp_size,
            tp_group,
        )

        return y_q, dgrad, wgrad, bgrad


@run_distributed_test()
def _test_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    """Test the linear layer with specified parallel mode and sequence parallelization.

    Args:
        parallel_mode (str): 'row' or 'column' parallelism.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the linear layer.

    QUANTIZATION options: nvfp4 <=> custom nvfp4 as a reference
    """
    params_dtype = torch.bfloat16
    use_bias = kwargs.get("bias", True)
    fuse_wgrad_accumulation = kwargs.get("fuse_wgrad_accumulation", False)
    seed = torch.initial_seed()
    recipe = quantization_recipe()

    # turn on weight quantization cache when fusing wgrad accumulation
    enable_weight_cache = fuse_wgrad_accumulation
    run_num_steps = 1 if not fuse_wgrad_accumulation else 5

    x, w, bias, gradient = TestDistributedLinearBase._prepare_data(
        BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
    )

    # run the recipe under test
    with te.autocast(enabled=True, recipe=recipe):
        y_q, dgrad, wgrad, bgrad = TestDistributedLinearBase.run_linear(
            x,
            w,
            bias,
            gradient,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=NCCL_WORLD,
            tp_size=WORLD_SIZE,
            rank=WORLD_RANK,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            # NOTE(review): the two args below re-derive the values already held
            # in run_num_steps / enable_weight_cache above; they are equivalent.
            run_num_steps=1 if not fuse_wgrad_accumulation else 5,
            enable_weight_cache=fuse_wgrad_accumulation,
        )

    # run the reference
    reference_recipe = quantization_reference_recipe()
    with te.autocast(enabled=True, recipe=reference_recipe):
        y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = TestDistributedLinearBase.run_linear(
            x,
            w,
            bias,
            gradient,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=NCCL_WORLD,
            tp_size=WORLD_SIZE,
            rank=WORLD_RANK,
            fuse_wgrad_accumulation=fuse_wgrad_accumulation,
            run_num_steps=run_num_steps,
            enable_weight_cache=enable_weight_cache,
        )

    # compare results, zero tolerance
    if WORLD_RANK == 0:
        torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
        torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
        torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
        if bgrad is not None and bgrad_ref is not None:
            torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")
def test_linear():
    """Run linear layer tests with various configurations."""
    kwargs_list = [
        {"bias": False},
    ]

    for kwargs in kwargs_list:
        # save_original_input is incompatible with plain fp8 quantization
        if kwargs.get("save_original_input", False) and QUANTIZATION == "fp8":
            continue
        for parallel_mode in ["column", "row"]:
            for sequence_parallel in [False, True]:
                _test_linear(parallel_mode, sequence_parallel, **kwargs)


############################################
#              LayerNormLinear             #
############################################
class TestDistributedLayerNormLinearBase(TestDistributedLinearBase):
    """Helpers for testing te.LayerNormLinear; reuses the Linear base utilities."""

    @classmethod
    def run_linear_one_step(cls, layer, x, gradient, is_first_microbatch=None):
        """One fwd/bwd step; returns (y_q, ln_out, dgrad, wgrad, bgrad).

        Unlike the base class, the layer returns the layernorm output as well
        (return_layernorm_output=True).
        """
        # reset gradients
        layer.zero_grad()
        x.grad = None

        # Forward pass
        y_q, ln_out = layer.forward(x, is_first_microbatch=is_first_microbatch)

        # Backward pass
        y_q.backward(gradient)

        # Collect gradients
        dgrad = x.grad

        parameters = layer._parameters

        # bias and weight gradients
        bgrad = parameters["bias"].grad if parameters.get("bias", None) is not None else None
        assert "weight" in parameters
        wgrad = parameters["weight"].grad

        return y_q, ln_out, dgrad, wgrad, bgrad

    @classmethod
    def run_linear_multiple_steps(
        cls, layer, x, gradient, run_num_steps, enable_weight_cache, fuse_wgrad_accumulation=False
    ):
        # raise error, no test case for multiple steps for now
        raise NotImplementedError("LayerNormLinear does not support test multiple steps for now")

    @classmethod
    def run_layernorm_linear(
        cls,
        x,
        w,
        bias,
        gradient,
        parallel_mode=None,
        sequence_parallel=False,
        tp_group=None,
        tp_size=1,
        rank=0,
        run_num_steps=1,
        enable_weight_cache=False,
        LayerNormLinearClass=te.LayerNormLinear,
        normalization="LayerNorm",
    ):
        """
        If Model parallel, split inputs for a given rank and return the gathered output and gradients, so that they can be compared with
        the reference single GPU run.
        """
        # clone inputs and move to current device
        # w has shape [N, K], x has shape [M, K], gradient has shape [M, N]
        x = x.clone().detach().requires_grad_(True).to("cuda")
        w = w.clone().detach().to("cuda")
        gradient = gradient.clone().detach().to("cuda")
        bias = bias.clone().detach().to("cuda") if bias is not None else None
        in_features = x.shape[1]
        out_features = w.shape[0]

        # If Model parallel: split inputs for a given rank
        x, w, bias, gradient = cls.run_linear_preprocess_parallel(
            x, w, bias, gradient, parallel_mode, sequence_parallel, tp_size, rank
        )

        # set data types
        params_dtype = x.dtype

        # Create linear layer and copy weights
        layer = LayerNormLinearClass(
            in_features,
            out_features,
            bias=bias is not None,
            params_dtype=params_dtype,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=tp_group,
            tp_size=tp_size,
            normalization=normalization,
            return_layernorm_output=True,
        )

        layer = layer.to("cuda")

        # Copy weights
        # kitchen_linear has different parameter names
        with torch.no_grad():
            layer.weight.copy_(w)
            if bias is not None:
                layer.bias.copy_(bias)

        # Run one step
        y_q, ln_out, dgrad, wgrad, bgrad = cls.run_linear_one_step(layer, x, gradient)

        # If Model parallel: gather output and gradients from all ranks
        y_q, dgrad, wgrad, bgrad = cls.run_linear_postprocess_parallel(
            y_q,
            dgrad,
            wgrad,
            bgrad,
            parallel_mode,
            sequence_parallel,
            tp_size,
            tp_group,
        )

        return y_q, ln_out, dgrad, wgrad, bgrad


@run_distributed_test()
def _test_layernorm_linear(parallel_mode=None, sequence_parallel=False, **kwargs):
    """Test the linear layer with specified parallel mode and sequence parallelization.

    Args:
        parallel_mode (str): 'column' parallelism.
        sequence_parallel (bool): Enable sequence parallelism if True.
        kwargs (dict): Additional arguments for the linear layer.
    """
    params_dtype = torch.bfloat16
    use_bias = kwargs.get("bias", True)
    seed = torch.initial_seed()
    recipe = quantization_recipe()

    # run multiple steps currently not supported for LayerNormLinear
    run_num_steps = 1

    x, w, bias, gradient = TestDistributedLayerNormLinearBase._prepare_data(
        BATCH_SIZE, HIDDEN_SIZE, OUT_SIZE, use_bias=use_bias, seed=seed, dtype=params_dtype
    )

    # run the recipe under test
    with te.autocast(enabled=True, recipe=recipe):
        y_q, ln_out, dgrad, wgrad, bgrad = TestDistributedLayerNormLinearBase.run_layernorm_linear(
            x,
            w,
            bias,
            gradient,
            parallel_mode=parallel_mode,
            sequence_parallel=sequence_parallel,
            tp_group=NCCL_WORLD,
            tp_size=WORLD_SIZE,
            rank=WORLD_RANK,
            run_num_steps=run_num_steps,
            enable_weight_cache=False,
        )

    # run the reference
    reference_recipe = quantization_reference_recipe()
    with te.autocast(enabled=True, recipe=reference_recipe):
        y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = (
            TestDistributedLayerNormLinearBase.run_layernorm_linear(
                x,
                w,
                bias,
                gradient,
                parallel_mode=parallel_mode,
                sequence_parallel=sequence_parallel,
                tp_group=NCCL_WORLD,
                tp_size=WORLD_SIZE,
                rank=WORLD_RANK,
                run_num_steps=run_num_steps,
                enable_weight_cache=False,
            )
        )

    # compare results, zero tolerance
    if WORLD_RANK == 0:
        torch.testing.assert_close(y_q, y_q_ref, atol=0, rtol=0, msg="Output mismatch")
        torch.testing.assert_close(ln_out, ln_out_ref, atol=0, rtol=0, msg="LN output mismatch")
        torch.testing.assert_close(dgrad, dgrad_ref, atol=0, rtol=0, msg="Dgrad mismatch")
        torch.testing.assert_close(wgrad, wgrad_ref, atol=0, rtol=0, msg="Wgrad mismatch")
        if bgrad is not None and bgrad_ref is not None:
            torch.testing.assert_close(bgrad, bgrad_ref, atol=0, rtol=0, msg="Bgrad mismatch")


def test_layernorm_linear():
    """Run LayerNormLinear tests; only column parallelism is supported."""
    kwargs_list = [
        {"bias": False},
    ]

    for kwargs in kwargs_list:
        for parallel_mode in ["column"]:
            for sequence_parallel in [False, True]:
                _test_layernorm_linear(parallel_mode, sequence_parallel, **kwargs)


if __name__ == "__main__":
    sys.exit(main())
def _get_quantization_recipe(quantization) -> Recipe:
    """Quantization recipe setup.

    Maps a quantization name ("fp8", "fp8_cs", "fp8_block") to the matching
    TE recipe object; raises ValueError for anything else.
    """
    fp8_format = Format.HYBRID
    if quantization == "fp8":
        return DelayedScaling(fp8_format=fp8_format, amax_history_len=32, amax_compute_algo="max")
    elif quantization == "fp8_cs":
        return Float8CurrentScaling(fp8_format=fp8_format)
    elif quantization == "fp8_block":
        return Float8BlockScaling(fp8_format=fp8_format)
    else:
        raise ValueError(f"Unsupported quantization: {quantization}")


def _get_raw_data(quantized_tensor):
    """Get the underlying uint8 data of a quantized tensor, used in zero-1 optimizer.

    Supports Float8Tensor (_data) and Float8BlockwiseQTensor (_rowwise_data).
    """
    if isinstance(quantized_tensor, Float8Tensor):
        assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
        assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
        return quantized_tensor._data
    elif isinstance(quantized_tensor, Float8BlockwiseQTensor):
        assert hasattr(
            quantized_tensor, "_rowwise_data"
        ), "Float8BlockwiseQTensor does not have _rowwise_data attribute"
        assert (
            quantized_tensor._rowwise_data.dtype == torch.uint8
        ), "Float8BlockwiseQTensor _rowwise_data must be uint8"
        return quantized_tensor._rowwise_data
    else:
        raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")


class MiniOptimizer:
    """Minimal data-parallel SGD reference: average grads, update FP32 master
    weights, copy back into the model weights. Used as ground truth for the
    zero-1/FSDP mini-optimizers below."""

    def __init__(self, weights, lr, dp_group):
        self.world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group

        # FP32 master copy of every weight (full replica on each rank)
        master_weights = []
        for weight in self.weights:
            master_weights.append(weight.detach().float())
        self.master_weights = master_weights

    def step(self):
        """Average main_grads across the group and apply one SGD update.

        NOTE: weight.main_grad is overwritten in place with the averaged
        gradient (via copy_ and /=).
        """
        for weight, master_weight in zip(self.weights, self.master_weights):
            main_grad = weight.main_grad

            # Don't use all-reduce directly to explicitly control the reduce order.
            # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
            buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
            dist.all_gather(buffers, main_grad, group=self.dp_group)
            for i in range(1, self.world_size):
                buffers[0] += buffers[i]
            main_grad.copy_(buffers[0])
            main_grad /= self.world_size

            master_weight -= main_grad * self.lr
            weight.data.copy_(master_weight)
class MiniZero_1:
    """A mini zero-1 optimizer implementation, just used for this test.

    Each rank owns a contiguous 1/world_size slice of the flattened parameter
    space: it keeps FP32 master weights only for the portion of each weight
    that overlaps its slice, reduce-scatters grads, updates its shard, and
    all-gathers the updated (possibly FP8) weights back to every rank.
    """

    def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False):
        self.rank = dist.get_rank(dp_group)
        self.world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group
        self.manual_post_all_gather_processing = manual_post_all_gather_processing

        # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
        self.offsets = [0]
        for weight in self.weights:
            self.offsets.append(self.offsets[-1] + weight.numel())

        # Pad so the global buffer divides evenly by world size; after padding,
        # offsets[-1] may extend past the end of the last weight.
        if self.offsets[-1] % self.world_size != 0:
            self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size

        self.master_weights = []
        # The start offset of the master weight in the weight
        self.start_offsets = []
        # The overlapping area of the weight and this rank's local buffer
        self.overlapping_areas = []

        # The start and end of this rank's local buffer in the global buffer
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size

        for weight, offset in zip(self.weights, self.offsets[:-1]):
            if offset >= rank_end or (offset + weight.numel()) <= rank_start:
                # This weight is not in this rank's local buffer
                master_weight = None
                start_offset = None
                overlapping_area = None
            else:
                overlapping_start = max(rank_start, offset)
                overlapping_end = min(rank_end, offset + weight.numel())
                length = overlapping_end - overlapping_start
                start_offset = overlapping_start - offset
                if isinstance(weight, QuantizedTensor):
                    # If weight is a FP8 tensor, we need to use the original high precision version
                    # to initialize the master weight.
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight = high_precision_init_val.to(weight.device).float()[
                        start_offset : start_offset + length
                    ]
                else:
                    master_weight = (
                        weight.detach().view(-1).float()[start_offset : start_offset + length]
                    )
                overlapping_area = (overlapping_start, overlapping_end)
            self.master_weights.append(master_weight)
            self.start_offsets.append(start_offset)
            self.overlapping_areas.append(overlapping_area)

        # Create global buffer for grads reduce-scatter
        self.grad_buffer = torch.empty(
            [self.offsets[-1]], dtype=torch.float32, device=weights[0].device
        )
        self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]

        # Create global buffer for weights all-gather
        # (uint8 when the weights are quantized: raw FP8 bytes travel over the wire)
        if isinstance(self.weights[0], QuantizedTensor):
            weight_buffer_dtype = torch.uint8
        else:
            weight_buffer_dtype = weights[0].dtype
        self.weight_buffer = torch.empty(
            [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
        )
        self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]

    def step(self):
        """One zero-1 step: reduce grads, update local shard, broadcast weights."""
        # -----------------------------------------------------------------------------------------
        # Step 1: Copy grads to the grad buffer
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))

        # -----------------------------------------------------------------------------------------
        # Step 2: Grads reduce-scatter
        # -----------------------------------------------------------------------------------------
        # Don't use reduce_scatter directly to explicitly control the reduce order.
        # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
        #                            group=self.dp_group)
        buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
        dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
        for i in range(1, self.world_size):
            buffers[0] += buffers[i]
        rank_start = self.offsets[-1] // self.world_size * self.rank
        rank_end = rank_start + self.offsets[-1] // self.world_size
        self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
        self.grad_buffer_slice /= self.world_size

        # -----------------------------------------------------------------------------------------
        # Step 3: Update master weights
        # -----------------------------------------------------------------------------------------
        for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
            if master_weight is None:
                # This weight's master weight is in other rank.
                continue
            grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
            master_weight -= grad * self.lr

        # -----------------------------------------------------------------------------------------
        # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
        # -----------------------------------------------------------------------------------------
        if isinstance(self.weights[0], QuantizedTensor):
            # FP8 weights case: all weights must be quantized, cast in one fused call
            for i in range(1, len(self.weights)):
                assert isinstance(self.weights[i], QuantizedTensor)
            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                self.start_offsets,
                self.dp_group,
                manual_post_all_gather_processing=self.manual_post_all_gather_processing,
            )
        else:
            # BF16 weights case: plain copy of the updated slice into the weight
            for weight, master_weight, start_offset in zip(
                self.weights, self.master_weights, self.start_offsets
            ):
                if master_weight is None:
                    continue
                start = start_offset
                end = start_offset + master_weight.numel()
                weight.data.view(-1)[start:end].copy_(master_weight)

        # -----------------------------------------------------------------------------------------
        # Step 5: Copy the updated weights (not all weights) to the weight buffer
        # -----------------------------------------------------------------------------------------
        for i in range(len(self.weights)):
            master_weight = self.master_weights[i]
            if master_weight is None:
                continue
            start_offset = self.start_offsets[i]
            if isinstance(self.weights[i], QuantizedTensor):
                weight = _get_raw_data(self.weights[i])
            else:
                weight = self.weights[i]
            weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
            overlapping_start, overlapping_end = self.overlapping_areas[i]
            self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)

        # -----------------------------------------------------------------------------------------
        # Step 6: Weight all-gather (FP8 or BF16)
        # -----------------------------------------------------------------------------------------
        dist.all_gather_into_tensor(
            self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
        )

        # -----------------------------------------------------------------------------------------
        # Step 7: Copy the gathered weights from weight buffer to the actual weights
        # -----------------------------------------------------------------------------------------
        for weight, offset in zip(self.weights, self.offsets[:-1]):
            start = offset
            end = offset + weight.numel()
            if isinstance(weight, QuantizedTensor):
                weight = _get_raw_data(weight)
            weight.view(-1).data.copy_(self.weight_buffer[start:end])

        if self.manual_post_all_gather_processing:
            # Rebuild transposed/columnwise FP8 data explicitly when the fused
            # cast was told not to do it.
            quantized_weights = [
                weight for weight in self.weights if isinstance(weight, QuantizedTensor)
            ]
            post_all_gather_processing(quantized_weights)
class MiniFSDP:
    """A mini FSDP-style optimizer for this test: all weights share one flat
    buffer, each rank owns one contiguous shard of it (plus the matching FP32
    master-weight slices), and step() does reduce-scatter / local update /
    all-gather over that flat buffer."""

    def __init__(self, weights, lr, dp_group, manual_post_all_gather_processing=False):
        rank = dist.get_rank(dp_group)
        world_size = dist.get_world_size(dp_group)

        self.weights = weights
        self.lr = lr
        self.dp_group = dp_group
        self.manual_post_all_gather_processing = manual_post_all_gather_processing

        # Flatten the weights and pad to align with world size
        if isinstance(weights[0], QuantizedTensor):
            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
        else:
            raw_data_list = [w.view(-1) for w in weights]
        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)

        # Split flattened weights into shards
        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
        self.local_main_grad_shard = torch.zeros_like(
            self.local_weight_shard, dtype=torch.float32, device="cuda"
        )
        shard_size = self.flatten_weight.size(0) // world_size

        # Map original tensors to flattened indices
        tensor_indices = []
        cumulative_length = 0
        for tensor in raw_data_list:
            length = tensor.size(0)
            tensor_indices.append((cumulative_length, cumulative_length + length))
            cumulative_length += length

        # Build shard index mappings:
        #   weight_indices[i]  - slice of weight i covered by this rank (weight-local)
        #   shard_indices[i]   - same slice expressed in shard-local coordinates
        self.weight_indices = []
        self.shard_indices = []
        for idx, (start, end) in enumerate(tensor_indices):
            shard_start = rank * shard_size
            shard_end = shard_start + shard_size
            adjusted_end = min(shard_end, original_length)

            if start <= adjusted_end and end >= shard_start:
                start_idx = max(start, shard_start)
                end_idx = min(end, adjusted_end)
                self.weight_indices.append((start_idx - start, end_idx - start))
                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
            else:
                self.weight_indices.append((None, None))
                self.shard_indices.append((None, None))

            # Re-point each weight's storage at its slice of the flat buffer so
            # the all-gather in step() updates the weights in place.
            if isinstance(weights[idx], QuantizedTensor):
                replace_raw_data(
                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
                )
            else:
                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)

        # Initialize local model weights and high-precision master weights
        self.local_weights = []
        self.master_weights = []
        for i, weight in enumerate(self.weights):
            weight_start, weight_end = self.weight_indices[i]
            shard_start, shard_end = self.shard_indices[i]
            if shard_start is not None and shard_end is not None:
                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
                self.local_weights.append(local_weight_shard)

                if isinstance(weight, QuantizedTensor):
                    # Use the preserved high-precision init values for the master copy
                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
                        weight_start:weight_end
                    ]
                else:
                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
                self.master_weights.append(master_weight_shard)
            else:
                self.local_weights.append(None)
                self.master_weights.append(None)
            setattr(
                weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
            )

    def _flatten_tensors_with_pad(self, tensors):
        """
        Flatten the list of tensors and pad them to align with the world size.

        Args:
            tensors (list): List of tensors to flatten.

        Returns:
            tuple: Flattened tensor and its original length before padding.
        """
        world_size = dist.get_world_size(self.dp_group)

        flatten_tensor = torch.cat(tensors)
        original_length = flatten_tensor.size(0)

        padding_needed = (world_size - original_length % world_size) % world_size
        if padding_needed > 0:
            zeros = torch.zeros(padding_needed, dtype=flatten_tensor.dtype, device="cuda")
            flatten_tensor = torch.cat([flatten_tensor, zeros])

        return flatten_tensor, original_length

    def zero_grad(self):
        # main_grad is zeroed (not dropped) because backward accumulates into it
        for weight in self.weights:
            weight.grad = None
            weight.main_grad.zero_()

    def step(self):
        """
        Perform an optimization step for the distributed sharded model.

        This method includes:
        1. Gradient reduce-scatter: Synchronize gradients across all processes.
        2. Master weight update: Update high-precision master weights using local gradients.
        3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
        4. Weight synchronization: All-gather updated weights across all processes.

        Returns:
            None
        """
        # Step 1: Reduce-scatter the gradients
        main_grad_buffer, _ = self._flatten_tensors_with_pad(
            [weight.main_grad.view(-1) for weight in self.weights]
        )
        dist.reduce_scatter_tensor(
            self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
        )
        self.local_main_grad_shard /= dist.get_world_size(self.dp_group)

        # Step 2: Update the master weights
        for weight, master_weight, (shard_start, shard_end) in zip(
            self.weights, self.master_weights, self.shard_indices
        ):
            if master_weight is None:
                continue

            # Extract the local gradient shard for this weight
            grad = self.local_main_grad_shard[shard_start:shard_end]

            # Update the master weight using gradient descent
            master_weight -= grad * self.lr

        # Step 3: Cast master weights to FP8 or BF16 precision
        if isinstance(self.weights[0], QuantizedTensor):
            local_weights = []
            for local_weight in self.local_weights:
                if local_weight is None:
                    local_weights.append(None)
                    continue

                local_weights.append(local_weight)

            cast_master_weights_to_fp8(
                self.weights,
                self.master_weights,
                [idx[0] for idx in self.weight_indices],
                self.dp_group,
                local_weights,
                manual_post_all_gather_processing=self.manual_post_all_gather_processing,
            )
        else:
            for weight, master_weight in zip(self.local_weights, self.master_weights):
                if master_weight is None:
                    continue

                # Copy updated master weights to local weights
                weight.data.copy_(master_weight)

        # Step 4: All-gather updated weights across processes
        # (weights alias flatten_weight, so this updates them in place)
        dist.all_gather_into_tensor(
            self.flatten_weight, self.local_weight_shard, group=self.dp_group
        )

        if self.manual_post_all_gather_processing:
            quantized_weights = [
                weight for weight in self.weights if isinstance(weight, QuantizedTensor)
            ]
            post_all_gather_processing(quantized_weights)
+ main_grad = main_grads[rank] + w1.main_grad = main_grad + w2.main_grad = main_grad + w3.main_grad = main_grad + + optimizer_1.step() + optimizer_2.step() + optimizer_3.step() + + for w1, w2 in zip(weights_1, weights_2): + torch.testing.assert_close(w1, w2, atol=0, rtol=0) + for w1, w3 in zip(weights_1, weights_3): + torch.testing.assert_close(w1, w3, atol=0, rtol=0) + + +def _test_cast_master_weights_to_fp8(quantization, dp_group, manual_post_all_gather_processing): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + torch.manual_seed(12345) + torch.cuda.manual_seed(12345) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True} + + # Create model with FP8 weights + with te.quantized_model_init( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + # Allocate main_grads for each weight + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda") + w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda") + + optimizer_fp8 = MiniZero_1( + [w for w in model_fp8.parameters()], 10.0, dp_group, 
manual_post_all_gather_processing + ) + optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group) + + for i in range(100): + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + w_fp8.main_grad.zero_() + w.main_grad.zero_() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. + target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + +def _test_fsdp_cast_master_weights_to_fp8( + quantization, dp_group, manual_post_all_gather_processing +): + rank = dist.get_rank(dp_group) + world_size = dist.get_world_size(dp_group) + + # Configuration constants + NUM_STEPS = 100 + SEED = 12345 + + torch.manual_seed(SEED) + torch.cuda.manual_seed(SEED) + + mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)] + mock_group = mock_groups[rank] + + linear_kwargs = { + "params_dtype": torch.bfloat16, + "bias": False, + "fuse_wgrad_accumulation": True, + } + + # Create model with FP8 weights + with te.quantized_model_init( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + preserve_high_precision_init_val=True, + ): + model_fp8 = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + 
te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Create model with BF16 weights + model = nn.Sequential( + te.Linear(128, 256 + 16, **linear_kwargs), + te.Linear(256 + 16, 256 * 3, **linear_kwargs), + te.Linear(256 * 3, 128, **linear_kwargs), + ) + + # Make sure the BF16 model and FP8 model have the same initial weights + for w_fp8, w in zip(model_fp8.parameters(), model.parameters()): + high_precision_init_val = w_fp8.get_high_precision_init_val() + w.data.copy_(high_precision_init_val) + + optimizer_fp8 = MiniFSDP( + [w for w in model_fp8.parameters()], 10.0, dp_group, manual_post_all_gather_processing + ) + optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group) + + for _ in range(100): + optimizer_fp8.zero_grad() + optimizer.zero_grad() + + inputs = [ + torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size) + ] + # Choose based on rank to make sure the inputs of different ranks are different. + x = inputs[rank] + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y_fp8 = model_fp8(x) + + with te.autocast( + enabled=quantization is not None, + recipe=_get_quantization_recipe(quantization), + amax_reduction_group=mock_group, + ): + y = model(x) + + targets = [torch.randn_like(y) for _ in range(world_size)] + # Choose based on rank to make sure the targets of different ranks are different. 
+ target = targets[rank] + loss_fp8 = nn.MSELoss()(y_fp8, target) + loss = nn.MSELoss()(y, target) + + loss_fp8.backward() + loss.backward() + + optimizer_fp8.step() + optimizer.step() + + torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0) + + +def run_parallel_tests() -> None: + """Run parallel tests""" + + WORLD_RANK = int(os.getenv("RANK", "0")) + WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1")) + LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0")) + LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1")) + + assert WORLD_SIZE == LOCAL_SIZE # this test supports only 1 node + assert LOCAL_SIZE <= torch.cuda.device_count() + dist_init_kwargs = { + "backend": "nccl", + "rank": WORLD_RANK, + "world_size": WORLD_SIZE, + "timeout": datetime.timedelta(seconds=30), + } + dist_init_kwargs["init_method"] = "env://" + dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}") + assert dist.is_nccl_available() + torch.cuda.set_device(LOCAL_RANK) + dist.init_process_group(**dist_init_kwargs) + dp_group = dist.new_group(backend="nccl") + + quantizations = [] + if is_fp8_available(): + quantizations.extend(["fp8", "fp8_cs"]) + if is_fp8_block_scaling_available(): + quantizations.append("fp8_block") + + manual_post_all_gather_processings = [False, True] + + _test_mini_optimizer(dp_group) + + for quantization in quantizations: + for post_ag_processing in manual_post_all_gather_processings: + _test_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group, post_ag_processing) + + dist.destroy_process_group() + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, reason="cast_master_weights_to_fp8 test needs at least 2 GPUs." 
+) +@pytest.mark.parametrize("world_size", [2]) +def test_cast_master_weights_to_fp8(world_size: int) -> None: + """Launch parallel job that runs parallel tests""" + python_exe = pathlib.Path(sys.executable).resolve() + current_file = pathlib.Path(__file__).resolve() + command = [ + python_exe, + "-m", + "torch.distributed.run", + f"--nproc_per_node={world_size}", + current_file, + "--parallel", + ] + result = subprocess.run( + command, + check=True, + ) -def _run_test(quantization): - test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py" - test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization] - result = subprocess.run(test_cmd, env=os.environ, check=False) - assert result.returncode == 0 +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", action="store_true", help="Run parallel tests") + args = parser.parse_args() + if args.parallel: + run_parallel_tests() -@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs", "fp8_block"]) -def test_cast_master_weights_to_fp8(quantization): - if quantization in ("fp8", "fp8_cs") and not fp8_available: - pytest.skip(reason_for_no_fp8) - if quantization == "fp8_block" and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - _run_test(quantization) +if __name__ == "__main__": + main() diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 74d1dc69c..ddb31c30f 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -9,13 +9,12 @@ import torch import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager if torch.cuda.device_count() < 2: pytest.skip("Comm+GEMM overlap requires at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, 
reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) RNG_SEED: int = 42 SEQ_LENGTH: int = 1024 diff --git a/tests/pytorch/distributed/test_fusible_ops.py b/tests/pytorch/distributed/test_fusible_ops.py index 6dc17b126..b9fbfb2a5 100644 --- a/tests/pytorch/distributed/test_fusible_ops.py +++ b/tests/pytorch/distributed/test_fusible_ops.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -22,31 +22,35 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( + QuantizedTensor, Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, + is_bf16_available, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.utils import is_bf16_compatible, is_fp8_fnuz +from transformer_engine.pytorch.utils import is_fp8_fnuz import transformer_engine_torch as tex # Import utility functions _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) -from utils import dtype_tols, make_recipe +from utils import dtype_tols, make_recipe, quantization_tols # Check what quantization schemes are supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() 
-mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
+fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
 quantization_list: list[Optional[str]] = [None]
 if fp8_available:
     quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling"))
 if mxfp8_available:
     quantization_list.append("mxfp8")
+if nvfp4_available:
+    quantization_list.append("nvfp4")
 
 
 @functools.cache
@@ -117,6 +121,14 @@ def make_reference_and_test_tensors(
         test = quantizer(test)
     elif quantization == "mxfp8":
         test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test)
+    elif quantization == "nvfp4":
+        test = NVFP4Quantizer(
+            with_rht=False,
+            with_post_rht_amax=False,
+            with_2d_quantization=False,
+            stochastic_rounding=False,
+            with_random_sign_mask=False,
+        )(test)
     else:
         raise ValueError(f"Unsupported quantization scheme ({quantization})")
     if isinstance(test, QuantizedTensor) and not test_is_quantized:
@@ -417,7 +429,7 @@ def _test_basic_linear(
 
     # Implementation with fusible operation
     recipe = make_recipe(quantization)
-    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+    with te.quantized_model_init(enabled=quantized_weight, recipe=recipe):
         op = te_ops.BasicLinear(
             in_features,
             out_features,
@@ -430,7 +442,7 @@
         with torch.no_grad():
             op.weight.copy_(w_test)
         del w_test
-    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+    with te.autocast(enabled=quantized_compute, recipe=recipe):
         y_test = op(x_test)
     y_test.backward(dy_test)
@@ -439,7 +451,7 @@
     if dtype == torch.float32:
         tols = dtype_tols(torch.float16)  # TF32 GEMM
     if quantized_compute:
-        tols = dtype_tols(tex.DType.kFloat8E4M3)
+        tols = quantization_tols(quantization)
 
     # Check results
     y_test = y_test.to(dtype=torch.float64, device="cpu")
@@ -583,7
+595,7 @@ def _test_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -602,7 +614,7 @@ def _test_linear( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -611,7 +623,7 @@ def _test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -625,6 +637,204 @@ def _test_linear( torch.testing.assert_close(db_test, db_ref, **tols) +def _test_mlp( + *, + bias: bool = True, + hidden_size: int = 32, + local_batch_size: int = 32, + dtype: torch.dtype = torch.float32, + device: torch.device = "cuda", + quantization: Optional[str] = None, + quantized_weight: bool = False, + sequence_parallel: bool = False, +) -> None: + """2-layer MLP + + MLP includes GELU activation in order to test op fusions. Model + performs warmup steps in order to test inter-step logic. 
+ + """ + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and quantized_weight: + return + + # Distributed process group + process_group = world_group() + rank = torch.distributed.get_rank(process_group) + world_size = torch.distributed.get_world_size(process_group) + + # Tensor dimensions + mlp_size = hidden_size * world_size + batch_size = local_batch_size + if sequence_parallel: + batch_size *= world_size + in_shape = (batch_size, hidden_size) + + # Random data + reset_rng() + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w1_ref, w1_test = make_reference_and_test_tensors( + (mlp_size, hidden_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b1_ref, b1_test = None, None + w2_ref, w2_test = make_reference_and_test_tensors( + (hidden_size, mlp_size), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = None, None + if bias: + b1_ref, b1_test = make_reference_and_test_tensors( + (mlp_size,), + test_dtype=dtype, + test_device=device, + ) + b2_ref, b2_test = make_reference_and_test_tensors( + (world_size, hidden_size), + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w1_ref) + if bias: + y_ref += b1_ref + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref = torch.nn.functional.linear(y_ref, w2_ref) + if bias: + y_ref += b2_ref.sum(dim=0) + y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh") + y_ref.backward(dy_ref) + + # Convert to distributed tensors + with torch.no_grad(): + local_mlp_size = mlp_size // world_size + 
local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size) + dx_ref = x_ref.grad + dw1_ref = w1_ref.grad[local_mlp_slice, :] + w1_ref = w1_ref[local_mlp_slice, :] + w1_test = w1_test[local_mlp_slice, :] + dw2_ref = w2_ref.grad[:, local_mlp_slice] + w2_ref = w2_ref[:, local_mlp_slice] + w2_test = w2_test[:, local_mlp_slice] + if bias: + db1_ref = b1_ref.grad[local_mlp_slice] + b1_ref = b1_ref[local_mlp_slice] + b1_test = b1_test[local_mlp_slice] + db2_ref = b2_ref.grad[rank, :] + b2_ref = b2_ref[rank, :] + b2_test = b2_test[rank, :] + else: + db1_ref = None + db2_ref = None + if sequence_parallel: + local_batch_slice = slice( + rank * local_batch_size, + (rank + 1) * local_batch_size, + ) + x_ref = x_ref[local_batch_slice, ...] + dx_ref = dx_ref[local_batch_slice, ...] + x_test = x_test[local_batch_slice, ...].clone() + y_ref = y_ref[local_batch_slice, ...] + dy_ref = dy_ref[local_batch_slice, ...] + dy_test = dy_test[local_batch_slice, ...].clone() + x_test.requires_grad_() + + # Implementation with fusible operation + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + model = te_ops.Sequential( + te_ops.GELU(), + te_ops.Linear( + hidden_size, + mlp_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="column", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + te_ops.Linear( + mlp_size, + hidden_size, + bias=bias, + device=device, + dtype=dtype, + tensor_parallel_mode="row", + tensor_parallel_group=process_group, + sequence_parallel=sequence_parallel, + ), + te_ops.GELU(), + ) + with torch.no_grad(): + model[1].weight.copy_(w1_test) + model[3].weight.copy_(w2_test) + if bias: + model[1].bias.copy_(b1_test) + model[3].bias.copy_(b2_test) + del w1_test, w2_test, b1_test, b2_test + + # Warmup steps + for _ in range(3): + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = model(x_test) + 
y_test.backward(dy_test) + x_test.grad = None + model[1].weight.grad = None + model[3].weight.grad = None + if bias: + model[1].bias.grad = None + model[3].bias.grad = None + + # Forward and backward step + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = model(x_test) + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, dx_ref, **tols) + torch.testing.assert_close(dw1_test, dw1_ref, **tols) + torch.testing.assert_close(dw2_test, dw2_ref, **tols) + if bias: + db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") + db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db1_test, db1_ref, **tols) + torch.testing.assert_close(db2_test, db2_ref, **tols) + + def _test_fp8_scale_update( *, amax_history_len: int = 31, @@ -736,7 +946,7 @@ def ref_amax_and_scale( amax_history_len=amax_history_len, amax_compute_algo=amax_compute_algo, ) - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -791,16 +1001,31 @@ def run_parallel_tests() -> None: for config in itertools.product( quantization_list, ("column", "row"), + (False, True), ): if rank == 0: print(f"Running _test_linear with {config=}") - quantization, tensor_parallel_mode = config - dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32 + quantization, tensor_parallel_mode, sequence_parallel = config + dtype = torch.bfloat16 if 
is_bf16_available() else torch.float32 _test_linear( bias=True, # bias=False is tested in _test_basic_linear dtype=dtype, quantization=quantization, tensor_parallel_mode=tensor_parallel_mode, + sequence_parallel=sequence_parallel, + ) + + # MLP + for config in itertools.product(quantization_list, (False, True)): + if rank == 0: + print(f"Running _test_mlp with {config=}") + quantization, sequence_parallel = config + dtype = torch.bfloat16 if is_bf16_available() else torch.float32 + _test_mlp( + bias=True, # bias=False is tested in _test_basic_linear + dtype=dtype, + quantization=quantization, + sequence_parallel=sequence_parallel, ) # FP8 scale update diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index d6ddfe27c..61c813b8f 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -16,23 +16,24 @@ import pytest import torch +from typing import Optional, Iterable + import transformer_engine import transformer_engine.pytorch as te import transformer_engine.pytorch.cpp_extensions as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( UserbuffersBackwardLinear, UserbuffersForwardLinear, ) -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, + QuantizedTensor, + Float8Tensor, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.utils import is_bf16_compatible + # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -40,8 +41,8 
@@ from utils import dtype_tols, make_recipe, str_to_dtype # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) quantization_list: list[Optional[str]] = [None] if fp8_available: quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) @@ -301,7 +302,7 @@ def _test_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): + with te.quantized_model_init(enabled=quantized_compute, recipe=recipe): ops = [] linear_op = None bias_op = None @@ -351,7 +352,7 @@ def _test_linear( bias_op.bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) diff --git a/tests/pytorch/distributed/test_numerics.py b/tests/pytorch/distributed/test_numerics.py index 1ff5aff99..97a69e779 100644 --- a/tests/pytorch/distributed/test_numerics.py +++ b/tests/pytorch/distributed/test_numerics.py @@ -8,7 +8,7 @@ import pytest import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te """ Distributed numerics tests @@ -26,11 +26,12 @@ if torch.cuda.device_count() < 2: pytest.skip("Distributed training needs at least 2 GPUs.") -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) 
+mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)
+fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available(
+    return_reason=True
 )
+nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True)
 
 TEST_ROOT = Path(__file__).parent.resolve()
 NUM_PROCS: int = min(4, torch.cuda.device_count())
@@ -51,7 +52,9 @@ def _run_test(quantization):
 all_boolean = [True, False]
 
 
-@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
+@pytest.mark.parametrize(
+    "quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling", "nvfp4"]
+)
 def test_distributed(quantization):
     if quantization == "fp8" and not fp8_available:
         pytest.skip(reason_for_no_fp8)
@@ -61,4 +64,6 @@ def test_distributed(quantization):
         pytest.skip(reason_for_no_mxfp8)
     if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if quantization == "nvfp4" and not nvfp4_available:
+        pytest.skip(reason_for_no_nvfp4)
     _run_test(quantization)
diff --git a/tests/pytorch/distributed/test_numerics_exact.py b/tests/pytorch/distributed/test_numerics_exact.py
new file mode 100644
index 000000000..72aa78664
--- /dev/null
+++ b/tests/pytorch/distributed/test_numerics_exact.py
@@ -0,0 +1,70 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import subprocess
+from pathlib import Path
+
+import pytest
+import torch
+import transformer_engine.pytorch as te
+
+"""
+    Distributed numerics tests
+
+    This numerical test aims for a zero-tolerance test for absolute confidence in numerics.
+    In the case of NVFP4, with the custom NVFP4 quantization, we matched the bitwise
+    result with the native silicon. For distributed test cases, we can do the same thing
+    by comparing BF16 AG results with the low precision AG results at layer level.
+""" + + +if torch.cuda.device_count() < 2: + pytest.skip("Distributed training needs at least 2 GPUs.") + +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +TEST_ROOT = Path(__file__).parent.resolve() +NUM_PROCS: int = min(4, torch.cuda.device_count()) +LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"] + + +def _run_test(quantization, batch_size, hidden_size, out_size): + test_path = TEST_ROOT / "run_numerics_exact.py" + test_cmd = LAUNCH_CMD + [str(test_path)] + + test_cmd += ["--quantization", quantization] + test_cmd += ["--batch-size", str(batch_size)] + test_cmd += ["--hidden-size", str(hidden_size)] + test_cmd += ["--out-size", str(out_size)] + + result = subprocess.run(test_cmd, env=os.environ, check=False) + assert result.returncode == 0 + + +all_boolean = [True, False] + + +@pytest.mark.parametrize("quantization", ["nvfp4"]) +@pytest.mark.parametrize( + "batch_size, hidden_size, out_size", + [ + (64, 128, 128), + (128, 128, 128), + (128, 256, 256), + (512, 1024, 768), + (512, 256, 1024), + (2048, 2048, 2048), + ], +) +def test_distributed(quantization, batch_size, hidden_size, out_size): + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + + _run_test(quantization, batch_size, hidden_size, out_size) diff --git a/tests/pytorch/distributed/test_sanity.py b/tests/pytorch/distributed/test_sanity.py index 39494a92b..fbbbe2997 100644 --- a/tests/pytorch/distributed/test_sanity.py +++ b/tests/pytorch/distributed/test_sanity.py @@ -7,8 +7,7 @@ import pytest import torch import transformer_engine -from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention -from 
transformer_engine.pytorch import TransformerLayer, Linear +from transformer_engine.pytorch import DotProductAttention, TransformerLayer, Linear _current_file = pathlib.Path(__file__).resolve() sys.path.append(str(_current_file.parent.parent)) diff --git a/tests/pytorch/distributed/test_torch_fsdp2.py b/tests/pytorch/distributed/test_torch_fsdp2.py index f5c186a3b..91d6fc6ed 100644 --- a/tests/pytorch/distributed/test_torch_fsdp2.py +++ b/tests/pytorch/distributed/test_torch_fsdp2.py @@ -6,47 +6,54 @@ import pytest import subprocess from pathlib import Path -from transformer_engine.pytorch import torch_version -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te import torch -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() - +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) NUM_PROCS: int = torch.cuda.device_count() -def _run_test(fp_init, sharding_dims): +def _run_test(fp_init, sharding_dims, recipe, layer_type): test_path = Path(__file__).parent.resolve() / "run_fsdp2_model.py" test_cmd = ["torchrun", f"--nproc_per_node={NUM_PROCS}", str(test_path)] if fp_init: test_cmd += ["--fp8-init"] + if len(sharding_dims) == 1: test_cmd += ["--sharding-dims", str(sharding_dims[0])] elif len(sharding_dims) == 2: test_cmd += ["--sharding-dims", str(sharding_dims[0]), str(sharding_dims[1])] else: assert False + test_cmd += ["--recipe", recipe] + test_cmd += ["--layer-type", layer_type] + result = subprocess.run(test_cmd, env=os.environ, check=True) @pytest.mark.skipif(NUM_PROCS < 4, reason="Requires 4+ GPUs") @pytest.mark.skipif(NUM_PROCS % 2 != 0, reason="Requires even number of GPUs") -@pytest.mark.skipif(not torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") +@pytest.mark.skipif(not te.torch_version() >= (2, 4, 0), reason="Requires PyTorch 2.4.0+") 
@pytest.mark.parametrize("sharding_dims", ([NUM_PROCS], [2, NUM_PROCS // 2])) @pytest.mark.parametrize("fp8_init", (False, True)) -def test_distributed(fp8_init, sharding_dims): +@pytest.mark.parametrize("recipe", ("delayed_scaling", "current_scaling", "mx_fp8_block_scaling")) +@pytest.mark.parametrize("layer_type", ("LayerNormLinear", "TransformerLayer")) +def test_distributed(fp8_init, sharding_dims, recipe, layer_type): # Skip invalid configurations if torch.cuda.device_count() < 4: pytest.skip("FSDP2 test requires at least 4 GPUs") - if fp8_init and not fp8_available: + if recipe == "mx_fp8_block_scaling" and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + elif not fp8_available: pytest.skip(reason_for_no_fp8) - _run_test(fp8_init, sharding_dims) + _run_test(fp8_init, sharding_dims, recipe, layer_type) def test_dummy() -> None: diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py new file mode 100644 index 000000000..6009643ff --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -0,0 +1,242 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch.constants import TE_DType +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def check_nvfp4_gemm_versus_reference( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, + accumulate: bool, + *, + x_columnwise: bool = False, + w_columnwise: bool = False, +): + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input tensors + x_shape = (K, M) if x_columnwise else (M, K) + w_shape = (K, N) if w_columnwise else (N, K) + x = torch.randn(x_shape, dtype=x_dtype, device=device) + w = torch.randn(w_shape, dtype=w_dtype, device=device) + + # Setup out tensor if accumulate is True + if accumulate: + out = torch.randn((M, N), dtype=out_dtype, device=device) + else: + out = None + + # Native TE NVFP4 quantization + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + # Quantize x and w + x_nvfp4_native = x_quantizer.make_empty( + x_shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_native = x_quantizer.update_quantized(x, x_nvfp4_native) + w_nvfp4_native = w_quantizer.make_empty( + w_shape, dtype=w_dtype, device=device, requires_grad=False + ) + w_nvfp4_native = 
w_quantizer.update_quantized(w, w_nvfp4_native) + + # Extract quantized data from native NVFP4Tensors + qx_data = ( + x_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if x_columnwise + else x_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + qw_data = ( + w_nvfp4_native._columnwise_data.view(dtype=torch.uint8) + if w_columnwise + else w_nvfp4_native._rowwise_data.view(dtype=torch.uint8) + ) + sx_native = ( + x_nvfp4_native._columnwise_scale_inv if x_columnwise else x_nvfp4_native._rowwise_scale_inv + ) + sw_native = ( + w_nvfp4_native._columnwise_scale_inv if w_columnwise else w_nvfp4_native._rowwise_scale_inv + ) + + # Trim quantized data to match the actual tensor dimensions (remove padding) + qx_data = qx_data[:M, :] + qw_data = qw_data[:N, :] + + # NVFP4 uses 16-element blocks, trim scales to remove padding + block_length = 16 # NVFP4 uses 16-element blocks + expected_sx_cols = expected_sw_cols = K // block_length + # Trim the scales to remove padding + sx_trimmed = sx_native[:M, :expected_sx_cols] + sw_trimmed = sw_native[:N, :expected_sw_cols] + + # Native scales are stored as uint8 but need to be interpreted as float8_e4m3fn + # for the reference GEMM to work correctly + sx_trimmed = sx_trimmed.view(torch.float8_e4m3fn) + sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) + + # Create reference quantizer for reference GEMM + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=True, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + + # Create reference quantized tensors needed by reference GEMM + x_nvfp4_ref = ref_quantizer.quantize(x) + w_nvfp4_ref = ref_quantizer.quantize(w) + + # Reference GEMM using quantizer's qgemm method + y_ref = ref_quantizer.qgemm( + qx=qx_data, + qw=qw_data, + m_params=None, # MMParams not used in reference + out_dtype=out_dtype, + sx=sx_trimmed, + sw=sw_trimmed, + bias=None, # No bias for this test + out=out.clone() if accumulate else None, + 
accumulate=accumulate, + gemm_type=None, # GEMMType not used in reference + qresult_x=x_nvfp4_ref, + qresult_w=w_nvfp4_ref, + ) + + # Native TE GEMM using tex.generic_gemm (cuBLAS GEMM) + # Allocate cuBLAS workspace + workspace = torch.empty(4, dtype=torch.uint8, device=device) + + transa = True if not w_columnwise else False + transb = False if not x_columnwise else True + out_quantizer = None + bias = None + bias_dtype = TE_DType[torch.bfloat16] + use_gelu = False + gelu_input = None + use_grad = False + use_split_accumulator = False + + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] + + # just in case of accumulation, make sure y_ref and y_native are not the same tensor + assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" + # Reset nans to zeros because torch.assert_close does not assume nans to be equal + assert not torch.isnan(y_ref.float()).all(), "All elements are nan" + y_ref = torch.where(y_ref.isnan(), torch.zeros_like(y_ref), y_ref) + y_native = torch.where(y_native.isnan(), torch.zeros_like(y_native), y_native) + + # Compare results with some tolerance + torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) 
+@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("accumulate", [True, False], ids=["accumulate", "no_accumulate"]) +@pytest.mark.parametrize( + "is_x_columnwise, is_w_columnwise", + [ + (False, False), # Only rowwise x rowwise is supported by reference GEMM + # Note: Reference GEMM expects inputs as (M,K) x (N,K) with rowwise quantization + # Columnwise layouts are not supported by the reference implementation + ], + ids=["rowxrow"], +) +def test_nvfp4_gemm_versus_reference( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + accumulate: bool, + is_x_columnwise: bool, + is_w_columnwise: bool, +): + check_nvfp4_gemm_versus_reference( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, + accumulate=accumulate, + x_columnwise=is_x_columnwise, + w_columnwise=is_w_columnwise, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py new file mode 100644 index 000000000..0292063ab --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -0,0 +1,583 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import utils + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +class GetRecipes: + @staticmethod + def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(random_hadamard_transform=True) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(random_hadamard_transform=False) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(random_hadamard_transform=True) + return nvfp4_recipe + + @staticmethod + def nvfp4_2d_quantization_only(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams(fp4_2d_quantization=False) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams(fp4_2d_quantization=False) + return nvfp4_recipe + + @staticmethod + def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + @staticmethod + def nvfp4_recipe_to_test(with_rht: bool = False, with_2d_quantization: bool = False): + if with_rht and with_2d_quantization: + return 
GetRecipes.nvfp4_rht_and_2d_quantization() + elif with_rht: + return GetRecipes.nvfp4_rht_only() + elif with_2d_quantization: + return GetRecipes.nvfp4_2d_quantization_only() + else: + return GetRecipes.nvfp4_vanilla() + + +def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bool = False): + """ + Create a quantizer factory for NVFP4 reference implementation. + + This factory returns NVFP4QuantizerRef instances based on the role and configuration. + Used with CustomRecipe to create reference quantizers. + + Args: + with_rht: Whether to enable random Hadamard transform + with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights) + + Returns: + A factory function that takes a role string and returns a quantizer instance + """ + + def factory(role): + if role == "linear_input": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=with_rht, + ) + elif role == "linear_weight": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), + pow_2_scales=False, + with_rht=False, + ) + elif role == "linear_output": + # Output quantization not used + return None + elif role == "linear_grad_output": + return quantization_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=with_rht, + ) + elif role == "linear_grad_input": + # Grad input quantization not used + return None + else: + # For any other roles, return None + return None + + return factory + + +def reset_rng_states(): + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + +def check_nvfp4_module_versus_reference( + module_class, + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 
module against reference implementation. + + Args: + module_class: te.Linear or te.LayerNormLinear + in_features: Input feature dimension + out_features: Output feature dimension + bias: Whether to use bias + x_dtype: Input tensor dtype + num_steps: Number of forward/backward steps to test + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + reset_rng_states() + + # Create native module + print("\nCreate native module") + if module_class == te.Linear: + native_module = te.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.LayerNormLinear: + native_module = te.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + else: + raise ValueError(f"Unsupported module class: {module_class}") + + # Create reference module with same weights + reset_rng_states() + + # Create reference module + print("Create reference module") + if module_class == te.Linear: + ref_module = te.Linear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + elif module_class == te.LayerNormLinear: + ref_module = te.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + ) + + # Sync weights between native and reference modules + with torch.no_grad(): + # Copy main weight and bias parameters + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + + # Copy layer norm parameters if they exist + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + 
ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + # Create recipes for native and reference implementations + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization) + nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory) + + # Training loop comparison + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + with te.autocast(enabled=True, recipe=nvfp4_recipe): + # enable weight cache by giving is_first_microbatch + y_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + with te.autocast(enabled=True, recipe=nvfp4_ref_recipe): + y_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + # Store results + native_outputs.append( + { + "output": y_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) 
+ + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results across all steps + for step in range(num_steps): + native_out = native_outputs[step] + ref_out = ref_outputs[step] + + # Compare outputs + torch.testing.assert_close( + native_out["output"], + ref_out["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + + # Compare input gradients + torch.testing.assert_close( + native_out["input_grad"], + ref_out["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Input gradient mismatch at step {step}. Native: {native_out['input_grad']}, Ref:" + f" {ref_out['input_grad']}" + ), + ) + + # Compare weight gradients + torch.testing.assert_close( + native_out["weight_grad"], + ref_out["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=( + f"Weight gradient mismatch at step {step}. 
Native: {native_out['weight_grad']}," + f" Ref: {ref_out['weight_grad']}" + ), + ) + + # Compare bias gradients + if bias and native_out["bias_grad"] is not None and ref_out["bias_grad"] is not None: + torch.testing.assert_close( + native_out["bias_grad"], + ref_out["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + (512, 512), + (768, 3072), + (1024, 4096), + ], +) +# @pytest.mark.parametrize("bias", [True, False], ids=["with_bias", "no_bias"]) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1, 3], ids=["single_step", "multi_step"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + """Test NVFP4 Linear module against reference implementation.""" + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_module_versus_reference( + module_class=te.Linear, + in_features=in_features, + out_features=out_features, + bias=bias, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) + + +def check_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int = 1, + with_rht: bool = False, + with_2d_quantization: bool = False, +): + """ + Compare native NVFP4 LayerNormLinear module against 
reference implementation, + including ln_out. + """ + device = "cuda" + batch_size = 32 + seq_len = 128 + + # Create both modules with identical initialization + reset_rng_states() + + # Native module + native_module = te.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Reference module + reset_rng_states() + ref_module = te.LayerNormLinear( + in_features=in_features, + out_features=out_features, + bias=bias, + device=device, + params_dtype=x_dtype, + normalization=normalization, + return_layernorm_output=True, + ) + + # Sync weights and LN params + with torch.no_grad(): + if hasattr(native_module, "weight") and hasattr(ref_module, "weight"): + ref_module.weight.copy_(native_module.weight) + if bias and hasattr(native_module, "bias") and hasattr(ref_module, "bias"): + ref_module.bias.copy_(native_module.bias) + if hasattr(native_module, "layer_norm_weight") and hasattr(ref_module, "layer_norm_weight"): + if ( + native_module.layer_norm_weight is not None + and ref_module.layer_norm_weight is not None + ): + ref_module.layer_norm_weight.copy_(native_module.layer_norm_weight) + if hasattr(native_module, "layer_norm_bias") and hasattr(ref_module, "layer_norm_bias"): + if native_module.layer_norm_bias is not None and ref_module.layer_norm_bias is not None: + ref_module.layer_norm_bias.copy_(native_module.layer_norm_bias) + + # Create recipes for native and reference implementations + nvfp4_recipe = GetRecipes.nvfp4_recipe_to_test(with_rht, with_2d_quantization) + nvfp4_ref_factory = get_nvfp4_quantizer_factory(with_rht, with_2d_quantization) + nvfp4_ref_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_factory) + + native_outputs = [] + ref_outputs = [] + + for step in range(num_steps): + torch.manual_seed(1234 + step) + torch.cuda.manual_seed(1234 + step) + + x_shape = (batch_size, seq_len, in_features) + x_val = 
torch.normal(mean=0.0, std=1.0, size=x_shape, dtype=x_dtype, device=device) + x_native = x_val.clone().detach().requires_grad_(True) + x_ref = x_native.clone().detach().requires_grad_(True) + + grad_output_shape = (batch_size, seq_len, out_features) + grad_output_val = torch.normal( + mean=0.0, std=1.0, size=grad_output_shape, dtype=x_dtype, device=device + ) + grad_output = grad_output_val.clone().detach() + + # Native forward/backward + with te.autocast(enabled=True, recipe=nvfp4_recipe): + y_native, ln_out_native = native_module(x_native, is_first_microbatch=(step == 0)) + y_native.backward(grad_output) + + # Reference forward/backward + with te.autocast(enabled=True, recipe=nvfp4_ref_recipe): + y_ref, ln_out_ref = ref_module(x_ref) + y_ref.backward(grad_output) + + native_outputs.append( + { + "output": y_native.detach().clone(), + "ln_out": ln_out_native.detach().clone(), + "input_grad": ( + x_native.grad.detach().clone() if x_native.grad is not None else None + ), + "weight_grad": ( + native_module.weight.grad.detach().clone() + if native_module.weight.grad is not None + else None + ), + "bias_grad": ( + native_module.bias.grad.detach().clone() + if bias and native_module.bias.grad is not None + else None + ), + } + ) + ref_outputs.append( + { + "output": y_ref.detach().clone(), + "ln_out": ln_out_ref.detach().clone(), + "input_grad": (x_ref.grad.detach().clone() if x_ref.grad is not None else None), + "weight_grad": ( + ref_module.weight.grad.detach().clone() + if ref_module.weight.grad is not None + else None + ), + "bias_grad": ( + ref_module.bias.grad.detach().clone() + if bias and ref_module.bias.grad is not None + else None + ), + } + ) + + # Compare results + for step in range(num_steps): + n = native_outputs[step] + r = ref_outputs[step] + torch.testing.assert_close( + n["output"], + r["output"], + atol=1e-6, + rtol=1e-6, + msg=f"Output mismatch at step {step}", + ) + torch.testing.assert_close( + n["ln_out"], + r["ln_out"], + atol=1e-6, + rtol=1e-6, 
+ msg=f"LN output mismatch at step {step}", + ) + torch.testing.assert_close( + n["input_grad"], + r["input_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Input gradient mismatch at step {step}", + ) + torch.testing.assert_close( + n["weight_grad"], + r["weight_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Weight gradient mismatch at step {step}", + ) + if bias and n["bias_grad"] is not None and r["bias_grad"] is not None: + torch.testing.assert_close( + n["bias_grad"], + r["bias_grad"], + atol=1e-6, + rtol=1e-6, + msg=f"Bias gradient mismatch at step {step}", + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "in_features, out_features", + [ + (128, 256), + (256, 128), + ], +) +@pytest.mark.parametrize("bias", [False], ids=["no_bias"]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("num_steps", [1], ids=["single_step"]) +@pytest.mark.parametrize("normalization", ["LayerNorm", "RMSNorm"], ids=["LayerNorm", "RMSNorm"]) +@pytest.mark.parametrize("with_rht", [True, False], ids=["with_rht", "no_rht"]) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["with_2d_quantization", "no_2d_quantization"] +) +def test_nvfp4_layernorm_linear_versus_reference( + in_features: int, + out_features: int, + bias: bool, + normalization: str, + x_dtype: torch.dtype, + num_steps: int, + with_rht: bool, + with_2d_quantization: bool, +): + if with_rht and x_dtype != torch.bfloat16: + pytest.skip("RHT is only supported for bfloat16 input") + + check_nvfp4_layernorm_linear_versus_reference( + in_features=in_features, + out_features=out_features, + bias=bias, + normalization=normalization, + x_dtype=x_dtype, + num_steps=num_steps, + with_rht=with_rht, + with_2d_quantization=with_2d_quantization, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py new file mode 100644 index 000000000..2467c7e2e --- 
/dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -0,0 +1,491 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pytest +import torch +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.common.recipe import NVFP4BlockScaling +from transformer_engine.pytorch.constants import TE_DType + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def check_quantization_nvfp4_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = 
x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + # Reference quantization + quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=quant_tile_shape, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + if x_nvfp4_ref.data_t is not None + else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + 
+@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # # largest tile + (8192, 8192), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize("swizzled_scale", [False], ids=["linear_scale"]) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] +) +def test_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + swizzled_scale: bool, + use_cpp_allocator: bool, + with_2d_quantization: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + return_transpose=return_transpose, + swizzled_scale=swizzled_scale, + use_cpp_allocator=use_cpp_allocator, + with_2d_quantization=with_2d_quantization, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (128, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("extrema_high", [False, True], ids=["zeros", "maxes"]) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_extrema_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + extrema_high: bool, + return_transpose: bool, + 
use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + if extrema_high: + x = torch.full((M, N), torch.finfo(x_dtype).max, dtype=x_dtype, device=device) + else: + x = torch.zeros((M, N), dtype=x_dtype, device=device) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : 
ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (16, 128), + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_boundary_values( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + """ + Stress rounding/threshold behavior by placing values just below/above + many potential bin edges within each 16-element microblock. + Validates native vs reference byte-for-byte and scale parity. + """ + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 123 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Construct a single row with paired boundary values: v-eps, v+eps + # spanning a wide dynamic range to exercise clipping and multiple bins. + # Ensure even N and N is multiple of 16 for microblocks, which holds for 128. 
+ base = torch.linspace(-12.0, 12.0, steps=N // 2, dtype=torch.float32, device=device) + eps = torch.full_like(base, 1e-3) + # Avoid zero eps for very small magnitudes + eps = torch.maximum(eps, 1e-4 * torch.ones_like(base)) + lower = base - eps + upper = base + eps + row = torch.empty(N, dtype=torch.float32, device=device) + row[0::2] = lower + row[1::2] = upper + x = row.unsqueeze(0).repeat(M, 1).to(dtype=x_dtype) + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + (M, N), dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + assert x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + torch.testing.assert_close(qx, qx_ref, 
atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim any padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + + device = "cuda" + seed = 17 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Start from a contiguous tensor, then make a non-contiguous view by transpose + x_base = torch.randn((M, N), dtype=x_dtype, device=device) + x_nc = x_base.t() # shape (N, M), non-contiguous + assert not x_nc.is_contiguous() + + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x_nc) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x_nc.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x_nc, x_nvfp4_sut) + + assert 
x_nvfp4_sut._rowwise_data is not None + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + ) + x_nvfp4_ref = ref_quantizer.quantize(x_nc) + + qx_ref = x_nvfp4_ref.data.view(dtype=torch.uint8) if x_nvfp4_ref.data is not None else None + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is not None else None + qx_t_ref = ( + x_nvfp4_ref.data_t.view(dtype=torch.uint8) if x_nvfp4_ref.data_t is not None else None + ) + sx_t_ref = ( + x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None + ) + ref_amax = x_nvfp4_ref.global_amax_row + + # Quantized must match + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + + # Compare only valid portion of scales (trim padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py new file mode 100644 index 000000000..904dfc2ea --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -0,0 +1,248 @@ +# Copyright (c) 2022-2025, NVIDIA 
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    """Unpack a 2D tensor of packed FP4 bytes into one 4-bit code per element.

    Each input byte holds two FP4 codes: the low nibble comes first and the
    high nibble second, so the output has twice as many columns as the input.
    """
    low_nibbles = x & 0x0F
    high_nibbles = x >> 4
    unpacked = torch.empty(
        (x.shape[0], x.shape[1] * 2), dtype=x.dtype, device=x.device
    )
    unpacked[:, 0::2] = low_nibbles
    unpacked[:, 1::2] = high_nibbles
    return unpacked
+ + te_dtype = tex.DType.kFloat4E2M1 + + # Setup device and random seed + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + # Input + x = torch.randn((M, N), dtype=x_dtype, device=device) + + x = x.transpose(0, 1) if not contiguous else x + + # Quantize + nvfp4_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=return_transpose, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=with_rht, + with_post_rht_amax=with_post_rht_amax, + with_random_sign_mask=with_random_sign_mask, + ) + if use_cpp_allocator: + x_nvfp4_sut = nvfp4_quantizer(x) + else: + x_nvfp4_sut = nvfp4_quantizer.make_empty( + x.shape, dtype=x_dtype, device=device, requires_grad=False + ) + x_nvfp4_sut = nvfp4_quantizer.update_quantized(x, x_nvfp4_sut) + + # Extract data from NVFP4Tensor + assert x_nvfp4_sut._rowwise_data is not None + qx: torch.Tensor = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + assert x_nvfp4_sut._rowwise_scale_inv is not None + sx: torch.Tensor = x_nvfp4_sut._rowwise_scale_inv + qx_t = ( + x_nvfp4_sut._columnwise_data.view(dtype=torch.uint8) + if x_nvfp4_sut._columnwise_data is not None + else None + ) + sx_t = x_nvfp4_sut._columnwise_scale_inv + amax_rowwise = x_nvfp4_sut._amax_rowwise + amax_colwise = x_nvfp4_sut._amax_columnwise + + qx = unpack_fp4(qx) + qx_t = unpack_fp4(qx_t) if qx_t is not None else None + + # Reference quantization using NVFP4QuantizerRef with built-in RHT + ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=return_transpose, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=with_rht, + with_random_sign_mask=with_random_sign_mask, + ) + x_nvfp4_ref = ref_quantizer.quantize(x) + # Extract data from RefNVFP4Tensor + qx_ref = ( + unpack_fp4(x_nvfp4_ref.data.view(dtype=torch.uint8)) + if x_nvfp4_ref.data is not None + else None + ) + sx_ref = x_nvfp4_ref.scale.view(dtype=torch.uint8) if x_nvfp4_ref.scale is 
not None else None + ref_amax_rowwise = x_nvfp4_ref.global_amax_row + + if return_transpose: + assert x_nvfp4_ref.data_t is not None + assert x_nvfp4_ref.scale_t is not None + qx_t_ref = unpack_fp4(x_nvfp4_ref.data_t.view(dtype=torch.uint8)) + sx_t_ref = x_nvfp4_ref.scale_t.view(dtype=torch.uint8) + # Compute transpose amax using the same reference quantizer + x_t_for_amax = ( + ref_quantizer._apply_rht(x.t().contiguous()) if with_rht else x.t().contiguous() + ) + ref_amax_colwise_t = torch.max(torch.abs(x_t_for_amax)).to(torch.float32).view(1) + else: + qx_t_ref = None + sx_t_ref = None + ref_amax_colwise_t = None + + torch.testing.assert_close(amax_rowwise, ref_amax_rowwise, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) + # Compare only the valid portion of scale tensors (reference may not have padding) + ref_sx_shape = sx_ref.shape + sx_valid = sx[: ref_sx_shape[0], : ref_sx_shape[1]] + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + + if return_transpose: + torch.testing.assert_close(amax_colwise, ref_amax_colwise_t, atol=0.0, rtol=0.0) + + torch.testing.assert_close(qx_t, qx_t_ref, atol=0.0, rtol=0.0) + + # Compare only the valid portion of transpose scale tensors + ref_sx_t_shape = sx_t_ref.shape + sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] + torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 256), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # Some larger tiles + (2048, 2048), + (1024, 2048), + (2048, 1024), + # Real shapes, + (8192, 5120), + (8192, 10240), + (8192, 2560), + (8192, 11328), + (8192, 512), + (8192, 3584), + (5120, 8192), + (10240, 8192), + (2560, 8192), + (11328, 8192), + (512, 8192), + (3584, 8192), + (4096, 16384), + (14336, 16384), + ], +) 
+@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_rht_with_quantization_block_tiling_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +) -> None: + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=True, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + (32, 128), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "return_transpose", [True, False], ids=["quantize_transpose", "skip_transpose"] +) +@pytest.mark.parametrize( + "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] +) +@pytest.mark.parametrize( + "with_random_sign_mask", [True, False], ids=["with_random_sign_mask", "no_random_sign_mask"] +) +def test_nvfp4_quantization_noncontiguous_inputs( + x_dtype: torch.dtype, + M: int, + N: int, + return_transpose: bool, + use_cpp_allocator: bool, + with_random_sign_mask: bool, +): + check_quantization_nvfp4_versus_reference( + x_dtype=x_dtype, + M=M, + N=N, + contiguous=False, + return_transpose=return_transpose, + use_cpp_allocator=use_cpp_allocator, + with_random_sign_mask=with_random_sign_mask, + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_sr_quantize.py new file mode 100755 index 000000000..0842de9ea --- /dev/null +++ 
def unpack_fp4(x: torch.Tensor) -> torch.Tensor:
    """Unpack packed FP4 bytes into one 4-bit code per element (low nibble first)."""
    repeated = x.repeat_interleave(2, dim=1)
    repeated[:, 0::2] &= 0x0F
    repeated[:, 1::2] >>= 4
    return repeated


# Lookup table mapping each 4-bit E2M1 code (0-15) to its float value.
_FP4_LUT = torch.tensor(
    [
        0.0,  # 0: 0000 - zero
        0.5,  # 1: 0001 - smallest positive normal
        1.0,  # 2: 0010
        1.5,  # 3: 0011
        2.0,  # 4: 0100
        3.0,  # 5: 0101
        4.0,  # 6: 0110
        6.0,  # 7: 0111 - largest positive normal
        -0.0,  # 8: 1000 - negative zero
        -0.5,  # 9: 1001 - smallest negative normal
        -1.0,  # 10: 1010
        -1.5,  # 11: 1011
        -2.0,  # 12: 1100
        -3.0,  # 13: 1101
        -4.0,  # 14: 1110
        -6.0,  # 15: 1111 - largest negative normal
    ],
    dtype=torch.float32,
)


def fp4_to_fp32(fp4: torch.Tensor) -> torch.Tensor:
    """Map FP4 E2M1 codes (0-15) to their float32 values via table lookup."""
    # Convert FP4 indices to their corresponding floating point values.
    # Each index (0-15) represents a 4-bit FP4 value in E2M1 format.
    fp4_lut = _FP4_LUT.to(fp4.device)
    return fp4_lut[fp4.to(torch.long)]


def dequantize_fp4(qx: torch.Tensor, sx: torch.Tensor, amax: torch.Tensor) -> torch.Tensor:
    """Dequantize packed FP4 data.

    ``sx`` holds one FP8(E4M3) scale byte per 16-element block; it is expanded
    to per-element scales, trimmed to the unpacked data shape, and combined
    with the global ``amax`` second-level scale (6.0 * 448 is the product of
    the FP4 and FP8-E4M3 maximum representable magnitudes).
    """
    sf = sx.repeat_interleave(16, dim=1).view(torch.float8_e4m3fn).to(torch.float32)
    dqx = fp4_to_fp32(unpack_fp4(qx))
    sf = sf[: dqx.shape[0], : dqx.shape[1]]
    dequant = dqx * sf * (amax / (6.0 * 448))
    return dequant


def RHT(x: torch.Tensor) -> torch.Tensor:
    """Reference randomized Hadamard transform over 16-element blocks of the last dim."""

    def get_wgrad_sign_vector() -> torch.Tensor:
        """Hard-coded signs for Hadamard transform"""
        return torch.tensor(
            [
                1.0,
                1.0,
                1.0,
                -1.0,
                1.0,
                -1.0,
                -1.0,
                -1.0,
                -1.0,
                -1.0,
                -1.0,
                1.0,
                -1.0,
                1.0,
                -1.0,
                -1.0,
            ],
            dtype=torch.float32,
        )

    def _build_hadamard_matrix(
        size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True
    ) -> torch.Tensor:
        """Construct a Hadamard matrix of given power-of-two size with entries +-1.

        Uses Sylvester construction to avoid SciPy dependency. When
        ``with_random_sign_mask`` is True (the default), the rows are sign-flipped
        by the hard-coded sign vector above.
        """
        assert (size & (size - 1)) == 0, "Hadamard size must be a power of two"
        h = torch.ones((1, 1), device=device, dtype=torch.float32)
        while h.shape[0] < size:
            h = torch.cat(
                [
                    torch.cat([h, h], dim=1),
                    torch.cat([h, -h], dim=1),
                ],
                dim=0,
            )
        if with_random_sign_mask:
            sign_mat = get_wgrad_sign_vector().to(device) * torch.eye(
                size, device=device, dtype=torch.float32
            )
            h = sign_mat @ h
        return h.to(dtype)

    rht_dim = 16
    # Build H and scale
    H = _build_hadamard_matrix(rht_dim, x.device, x.dtype)
    scale = 1.0 / float(rht_dim) ** 0.5

    # Perform blockwise transform along the last dimension
    original_shape = x.shape
    x_mat = x.contiguous().view(-1, rht_dim)
    # NOTE: H already has the hard-coded sign mask folded in (the builder's
    # with_random_sign_mask defaults to True), so this applies sign flips —
    # it is not a plain Hadamard transform.
    transform = H * scale
    out = x_mat @ transform
    return out.view(original_shape)
def check_quantization_nvfp4_versus_reference(
    x_dtype: torch.dtype, M: int, N: int, use_2D: bool, use_RHT: bool
) -> None:
    """Check that stochastic rounding (SR) beats round-to-nearest (RN) on average.

    Quantizes the same input once with RN and ``n_iters`` times with SR, then
    verifies that the RMSE of the averaged SR dequantization is lower than the
    RN RMSE for both the rowwise and the columnwise (transposed) outputs.
    """
    device = "cuda"
    torch.manual_seed(seed)
    n_iters = 50

    x = torch.randn((M, N), dtype=x_dtype, device=device) * 2 - 1
    y = x.t().contiguous()
    if use_RHT:
        y = RHT(y)
    amax = torch.max(torch.abs(x)).float()

    # Round-to-nearest baseline errors.
    q_rn, s_rn, q_t_rn, s_t_rn = quantize_fp4(
        x, use_stochastic_rounding=False, use_2D=use_2D, use_RHT=use_RHT
    )
    dq_rn = dequantize_fp4(q_rn, s_rn, amax)
    dq_t_rn = dequantize_fp4(q_t_rn, s_t_rn, amax)
    error_rn = (dq_rn - x).float()
    me_rn = torch.sqrt((error_rn * error_rn).mean())
    error_t_rn = (dq_t_rn - y).float()
    me_t_rn = torch.sqrt((error_t_rn * error_t_rn).mean())

    # Accumulate many stochastic-rounding runs; their mean should track the
    # original values more closely than a single round-to-nearest pass.
    sr_result = torch.zeros_like(x).float()
    sr_t_result = torch.zeros_like(x).float().t().contiguous()
    for _ in range(n_iters):
        q_sr, s_sr, q_t_sr, s_t_sr = quantize_fp4(
            x, use_stochastic_rounding=True, use_2D=use_2D, use_RHT=use_RHT
        )

        dq_sr = dequantize_fp4(q_sr, s_sr, amax)
        dq_t_sr = dequantize_fp4(q_t_sr, s_t_sr, amax)

        sr_result += dq_sr.float()
        sr_t_result += dq_t_sr.float()

    # Get the mean result of the stochastic rounding.
    # It should be more accurate than the RN result.
    sr_result /= n_iters
    error_sr = (sr_result - x).float()
    me_sr = torch.sqrt((error_sr * error_sr).mean())
    sr_t_result /= n_iters
    error_t_sr = (sr_t_result - y).float()
    me_t_sr = torch.sqrt((error_t_sr * error_t_sr).mean())

    print(f"RMSE SR: {me_sr:.3e} | RMSE RN: {me_rn:.3e}")
    print(f"RMSE SR_t: {me_t_sr:.3e} | RMSE RN_t: {me_t_rn:.3e}")
    assert me_sr < me_rn, "Stochastic rounding failed - error larger than the round to nearest."
    assert me_t_sr < me_t_rn, (
        "Stochastic rounding (transpose) failed - error larger than the round to nearest."
    )


@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe)
@pytest.mark.parametrize(
    "M, N",
    [
        (8192, 8192),
        (8192, 8256),  # to test the nonfused RHT path
    ],
)
@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str)
@pytest.mark.parametrize("use_2D", [False, True], ids=str)
@pytest.mark.parametrize("use_RHT", [False, True], ids=str)
def test_quantization_block_tiling_versus_reference(
    x_dtype: torch.dtype,
    use_2D: bool,
    use_RHT: bool,
    M: int,
    N: int,
) -> None:
    """Parametrized entry point for the SR-vs-RN comparison above."""
    if x_dtype == torch.float32 and use_RHT:
        pytest.skip("RHT is only supported with bfloat16")
    check_quantization_nvfp4_versus_reference(
        x_dtype=x_dtype,
        use_2D=use_2D,
        use_RHT=use_RHT,
        M=M,
        N=N,
    )
torch.nn.Module: if name == "ops_linear": return te.ops.Linear(1, 1) if name == "linear.fp8": - with te.fp8_model_init(recipe=make_recipe("fp8")): + with te.quantized_model_init(recipe=make_recipe("fp8")): return te.Linear(16, 16) if name == "ops_linear.fp8": - with te.fp8_model_init(recipe=make_recipe("fp8")): + with te.quantized_model_init(recipe=make_recipe("fp8")): return te.ops.Linear(16, 16) if name == "linear.mxfp8": - with te.fp8_model_init(recipe=make_recipe("mxfp8")): + with te.quantized_model_init(recipe=make_recipe("mxfp8")): return te.Linear(32, 32) if name == "ops_linear.mxfp8": - with te.fp8_model_init(recipe=make_recipe("mxfp8")): + with te.quantized_model_init(recipe=make_recipe("mxfp8")): return te.ops.Linear(32, 32) raise ValueError(f"Unrecognized module name ({name})") diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 6035a6528..4e4c71e14 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -1,29 +1,46 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
+import random import contextlib -import gc -import os -from typing import Iterable, Optional - import pytest +import os import torch - +from typing import Optional, List +from transformer_engine.pytorch.cpu_offload import ( + get_cpu_offload_context, + OffloadableLayerState, + DefaultOffloadSynchronizer, + start_offload, + mark_not_offload, +) +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends -from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported -from utils import ModelConfig, get_available_attention_backends +from utils import ModelConfig +import transformer_engine_torch as tex + +from torch.utils.cpp_extension import IS_HIP_EXTENSION # Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() -quantization_recipes: Optional[recipe.Recipe] = [None] +quantization_recipes: List[Optional[recipe.Recipe]] = [None] if fp8_available: quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) +if fp8_block_scaling_available: + quantization_recipes.append(recipe.Float8BlockScaling()) +if mxfp8_available: + quantization_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: + quantization_recipes.append(recipe.NVFP4BlockScaling()) + model_config = { "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), @@ -33,181 +50,716 @@ NUM_LAYERS = model_config["small"].num_layers EPSILON = model_config["small"].eps -# Flash attention saves some internal tensor for the backward pass -# that cannot be offloaded to 
CPU. -assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" +# Disable garbage collection to tests if there are reference cycles. +# We do not want them, because they can result in CUDA out of memory errors. +import gc -# Offloading is supported for attention only for fused and flash attention backends, -# so the use of bfloat16 is required. -# -# For the TransformerLayer, activation offloading with dropout is not supported, -# so we set hidden_dropout to 0.0. -model_types = { - "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), - "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), - "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), - "multihead_attention": lambda: te.MultiheadAttention( - SIZE, NUM_HEADS, params_dtype=torch.bfloat16 - ), - "transformer_layer": lambda: te.TransformerLayer( - SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 - ), - "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), - "layernorm_mlp_ops": lambda: te.ops.Sequential( - te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), - te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), - te.ops.GELU(), - te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), - ), -} +gc.disable() + + +class Utils: + tensor1 = torch.randn((1024, 1024), device="cuda", dtype=torch.bfloat16) + _B = 64 + _S = 256 + _H = 4 + _D = 256 + + @staticmethod + def long_job(stream: Optional[torch.cuda.Stream] = None): + NUM_ITERS = 6000 + if stream is None: + stream = torch.cuda.current_stream() + + with torch.cuda.stream(stream): + for i in range(NUM_ITERS): + Utils.tensor1.normal_() + + @staticmethod + def measure_time(func): + import time + + torch.cuda.synchronize() + start = time.time() + func() + torch.cuda.synchronize() + end = time.time() + return (end - start) * 1000 + + @staticmethod + def get_cuda_memory_mb(): + return torch.cuda.memory_allocated() / (1024**2) + + @staticmethod + def get_max_cuda_memory_mb(): + 
return torch.cuda.max_memory_allocated() / (1024**2) + + @staticmethod + def get_cpu_memory_mb() -> float: + import psutil, os + + return psutil.Process(os.getpid()).memory_info().rss / (1024**2) + + @staticmethod + def get_layer_names(): + return [ + "linear", + "layernorm_linear", + "layernorm_mlp", + "grouped_linear", + "multihead_attention", + "transformer_layer", + "linear_op", + "layernorm_mlp_ops", + ] + + @staticmethod + def create_layer(layer_type: str): + if layer_type == "linear": + return te.Linear(Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "layernorm_linear": + return te.LayerNormLinear(Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "layernorm_mlp": + return te.LayerNormMLP(Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "multihead_attention": + return te.MultiheadAttention( + Utils._D, Utils._H, attention_dropout=0.0, params_dtype=torch.bfloat16 + ) + elif layer_type == "grouped_linear": + return te.GroupedLinear(Utils._H, Utils._D, Utils._D, params_dtype=torch.bfloat16) + elif layer_type == "transformer_layer": + return te.TransformerLayer( + Utils._D, + Utils._D, + Utils._H, + attention_dropout=0.0, + hidden_dropout=0.0, + params_dtype=torch.bfloat16, + ) + elif layer_type == "linear_op": + return te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16) + elif layer_type == "layernorm_mlp_ops": + return te.ops.Sequential( + te.ops.LayerNorm(Utils._D, dtype=torch.bfloat16), + te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(Utils._D, Utils._D, dtype=torch.bfloat16), + ) + else: + raise ValueError(f"Unknown layer type: {layer_type}") + + @staticmethod + def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) -> torch.Tensor: + shape = (Utils._B, Utils._S, Utils._D) + tensor = torch.randn(shape, device="cuda", dtype=torch.bfloat16) + if recipe is None: + tensor = tensor.requires_grad_() if requires_grad else tensor + 
return tensor + elif recipe.delayed(): + quantizer = te.tensor.float8_tensor.Float8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + scale=torch.tensor([1.0], device="cuda"), + amax=torch.tensor([1.0], device="cuda"), + ) + return quantizer(tensor) + elif recipe.float8_current_scaling(): + quantizer = te.tensor.float8_tensor.Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, device="cuda" + ) + return quantizer(tensor) + elif recipe.float8_block_scaling(): + quantizer = te.tensor.float8_blockwise_tensor.Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, rowwise=True, columnwise=True + ) + return quantizer(tensor) + elif recipe.mxfp8(): + quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + return quantizer(tensor) + elif recipe.nvfp4(): + quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer() + return quantizer(tensor) + + @staticmethod + def create_recipe_ctx(recipe: Optional[recipe.Recipe]): + if recipe is None: + return lambda: contextlib.nullcontext() + else: + return lambda: te.fp8_autocast(fp8_recipe=recipe) + + @staticmethod + def get_tensor_size_mb(tensor): + if tensor is None: + return 0 + if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): + return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + else: + return tensor.numel() * tensor.element_size() / (1024**2) + + @staticmethod + def memory_leak_check(): + # Should be called before each test. + # Only cublas workspaces and some global tensors are allowed to be allocated. + # All other allocations should be released. + # This is a simple check to catch memory leaks. + if Utils.get_cuda_memory_mb() > 1000: + memory_num = Utils.get_cuda_memory_mb() + import gc + + gc.collect() # We want next test to be run with clean state. 
+ gc.disable() + raise RuntimeError(f"Memory leak: {memory_num} MB") + + +class TestsOffloadableLayerState: + @pytest.mark.parametrize("random_num_tensors", [True, False]) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_general(self, random_num_tensors, recipe): + """ + Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, + for each layer offload random number of random tensors. + Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors. + """ + Utils.memory_leak_check() + NUM_ITERATIONS = 10 + + stream = torch.cuda.Stream() + + offload_layer_state = OffloadableLayerState( + offload_stream=stream, + ) + for _ in range(NUM_ITERATIONS): + original_tensors = [] + tensors_ids = [] + NUM_TENSORS = random.choice([1, 20]) if random_num_tensors else 1 + for _ in range(NUM_TENSORS): + tensor = Utils.create_tensor(recipe) + original_tensors.append(tensor) + tensor_id = offload_layer_state.push_tensor(tensor) + assert tensor.device.type == "cuda" + tensors_ids.append(tensor_id) + + offload_layer_state.start_offload() + offload_layer_state.release_activation_forward_gpu_memory() + offload_layer_state.start_reload() + + for j in range(len(tensors_ids)): + tensor_gpu = offload_layer_state.pop_tensor(tensors_ids[j]) + assert tensor_gpu.device.type == "cuda" + assert tensor_gpu.shape == original_tensors[j].shape + assert tensor_gpu.dtype == original_tensors[j].dtype + torch.testing.assert_close(tensor_gpu, original_tensors[j]) + offload_layer_state.release_all_memory() + torch.cuda.synchronize() + + def test_offload_base_tensor(self): + Utils.memory_leak_check() + stream = torch.cuda.Stream() + offload_layer_state = OffloadableLayerState( + offload_stream=stream, + ) + init_cuda_memory = Utils.get_cuda_memory_mb() + x = Utils.create_tensor(None) + x_size = Utils.get_tensor_size_mb(x) + x_1 = x[::2] + x_2 = x[1::2] + + start_offload(x_1, offload_base_tensor=True) + 
start_offload(x_2, offload_base_tensor=True) + x1_id = offload_layer_state.push_tensor(x_1) + x2_id = offload_layer_state.push_tensor(x_2) + del x_1, x_2 + offload_layer_state.start_offload() + offload_layer_state.release_activation_forward_gpu_memory() + + assert offload_layer_state.get_offloaded_total_size_mb() == pytest.approx(x_size, 0.1) + + offload_layer_state.start_reload() + x_1 = offload_layer_state.pop_tensor(x1_id) + x_2 = offload_layer_state.pop_tensor(x2_id) + assert x_1.device.type == "cuda" + assert x_2.device.type == "cuda" + + assert torch.allclose(x_1, x[::2]) + assert torch.allclose(x_2, x[1::2]) + del x + + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + x_size, 0.1) + + +class TestsDefaultOffloadSynchronizer: + @pytest.mark.parametrize("random_num_tensors", [True, False]) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_general(self, random_num_tensors, recipe): + """ + Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, + for each layer offload random number of random tensors. + Then do backward pass for each layer, and check if reloaded tensors are equal to original tensors. 
+ """ + Utils.memory_leak_check() + NUM_LAYERS = 10 + NUM_ITERATIONS = 10 + + offload_synchronizer = DefaultOffloadSynchronizer( + num_layers=NUM_LAYERS, + num_offloaded_layers=NUM_LAYERS - 1, + ) -def _make_input() -> torch.Tensor: - """Generate random input tensor.""" - return torch.randn( - (128, SIZE, SIZE), - dtype=torch.bfloat16, - device="cuda", - requires_grad=True, - ) - - -def _warmup_model( - modules: Iterable[torch.nn.Module], - quantization_recipe: Optional[recipe.Recipe], -) -> None: - """Perform forward and backward pass""" - tensor = _make_input() - for module in modules: - with te.fp8_autocast( - enabled=quantization_recipe is not None, - fp8_recipe=quantization_recipe, + for _ in range(NUM_ITERATIONS): + original_tensors = [] + tensors_ids = [] + layer_ids = [] + + for i in range(NUM_LAYERS): + NUM_LAYER_TENSORS = random.randint(1, 10) if random_num_tensors else 1 + layer_tensors = [] + layer_tensors_ids = [] + layer_id = offload_synchronizer.fwd_step() + for _ in range(NUM_LAYER_TENSORS): + tensor = Utils.create_tensor(recipe) + layer_tensors.append(tensor) + tensor_id = offload_synchronizer.push_tensor(tensor) + assert tensor.device.type == "cuda" + layer_tensors_ids.append(tensor_id) + layer_ids.append(layer_id) + tensors_ids.append(layer_tensors_ids) + original_tensors.append(layer_tensors) + for i in range(NUM_LAYERS - 1, -1, -1): + offload_synchronizer.bwd_step(layer_ids[i]) + for j in range(len(tensors_ids[i])): + tensor_gpu = offload_synchronizer.pop_tensor(tensors_ids[i][j]) + assert tensor_gpu.device.type == "cuda" + assert tensor_gpu.shape == original_tensors[i][j].shape + assert tensor_gpu.dtype == original_tensors[i][j].dtype + torch.testing.assert_close(tensor_gpu, original_tensors[i][j]) + offload_synchronizer.finish_part_of_bwd() + torch.cuda.synchronize() + + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_memory(self, recipe): + torch.cuda.synchronize() + Utils.memory_leak_check() + NUM_LAYERS = 10 + + 
torch.cuda.reset_peak_memory_stats() + + offload_synchronizer = DefaultOffloadSynchronizer( + num_layers=NUM_LAYERS, + num_offloaded_layers=NUM_LAYERS - 1, + ) + + init_cuda_memory = Utils.get_cuda_memory_mb() + + tensor_ids = [] + + torch.cuda.synchronize() + for _ in range(NUM_LAYERS): + offload_synchronizer.fwd_step() + tensor = Utils.create_tensor(recipe) + tensor_size = Utils.get_tensor_size_mb(tensor) + tensor_id = offload_synchronizer.push_tensor(tensor) + assert tensor.device.type == "cuda" + tensor_ids.append(tensor_id) + del tensor, tensor_id + torch.cuda.synchronize() + + if recipe is None: + assert Utils.get_max_cuda_memory_mb() == pytest.approx( + init_cuda_memory + tensor_size, 0.1 + ) + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1) + + for i in range(NUM_LAYERS - 1, -1, -1): + offload_synchronizer.bwd_step(i) + tensor_gpu = offload_synchronizer.pop_tensor(tensor_ids[i]) + assert tensor_gpu.device.type == "cuda" + del tensor_gpu, tensor_ids[i] + offload_synchronizer.finish_part_of_bwd() + + del tensor_ids + torch.cuda.synchronize() + + if recipe is None: + assert Utils.get_max_cuda_memory_mb() == pytest.approx( + init_cuda_memory + tensor_size, 0.1 + ) + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_multiple_tensor_offload(self, recipe): + Utils.memory_leak_check() + init_cpu_memory = Utils.get_cpu_memory_mb() + init_cuda_memory = Utils.get_cuda_memory_mb() + offload_synchronizer = DefaultOffloadSynchronizer( + num_layers=2, + num_offloaded_layers=1, + ) + x1 = Utils.create_tensor(recipe) + x_size = Utils.get_tensor_size_mb(x1) + offload_synchronizer.fwd_step() + offload_synchronizer.push_tensor(x1) + offload_synchronizer.push_tensor(x1) + offload_synchronizer.push_tensor(x1) + offload_synchronizer.fwd_step() + # Only one copy of tensor on cpu is allocated. 
+ assert Utils.get_cpu_memory_mb() == pytest.approx(init_cpu_memory + 1 * x_size, 0.1) + del x1 + offload_synchronizer.bwd_step(1) + offload_synchronizer.bwd_step(0) + offload_synchronizer.finish_part_of_bwd() + + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + + +class TestTELayers: + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_sanity(self, layer_type, recipe): + Utils.memory_leak_check() + + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() ): - tensor = module(tensor) - tensor.sum().backward() - - -def _estimate_cached_weight_size( - model_name: str, - modules: Iterable[torch.nn.Module], - quantization_recipe: Optional[recipe.Recipe], -) -> float: - """Calculate the memory (in MiB) needed for weight caching.""" - - # The weight params are cached directly for unquantized compute - if quantization_recipe is None: - return 0 - - # Count number of weight param elements - param_elements = 0 - for module in modules: - for param in module.parameters(): - if param.dim() == 2: - param_elements += param.numel() - - # FP8 tensor-scaling caches one byte per element - if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): - if not is_non_tn_fp8_gemm_supported() and model_name not in ( - "linear_op", - "layernorm_mlp_ops", + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + + recipe_ctx = Utils.create_recipe_ctx(recipe) + init_cuda_memory = Utils.get_cuda_memory_mb() + OFFLOAD_LAYERS = 6 + NUM_LAYERS = 10 + offload_ctx, sync_function = get_cpu_offload_context( + enabled=True, + num_layers=OFFLOAD_LAYERS, + model_layers=NUM_LAYERS, + ) + layers = [Utils.create_layer(layer_type) for _ in range(NUM_LAYERS)] + inp = Utils.create_tensor(None) + m_splits = ( + {"m_splits": [Utils._B * 
Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + out = inp + for i in range(NUM_LAYERS): + with offload_ctx, recipe_ctx(): + # Ops-based layers don't support is_first_microbatch parameter + if layer_type in ["linear_op", "layernorm_mlp_ops"]: + out = layers[i](out, **m_splits) + else: + out = layers[i](out, is_first_microbatch=False, **m_splits) + out = sync_function(out) + out.sum().backward() + torch.cuda.synchronize() + del out, inp, layers + + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_memory(self, layer_type, recipe): + Utils.memory_leak_check() + + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() ): - # Modules do not deallocate FP8 transpose for weights - return 2 * param_elements / 1024**2 - return param_elements / 1024**2 + pytest.skip("Fusible operations do not support FP8 block scaling recipe") - # MXFP8 caches one data byte per element and one scale byte per 32 - # elements - if quantization_recipe.mxfp8(): - if model_name not in ("linear_op", "layernorm_mlp_ops"): - # Modules do not deallocate column-wise MXFP8 data for weights - return 2 * param_elements * (1 + 1 / 32) / 1024**2 - return param_elements * (1 + 1 / 32) / 1024**2 + offload_ctx, sync_function = get_cpu_offload_context( + enabled=True, + num_layers=1, + model_layers=2, + offload_activations=True, + offload_weights=False, + ) + recipe_ctx = Utils.create_recipe_ctx(recipe) + layer = Utils.create_layer(layer_type) + inp = Utils.create_tensor(None) + + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + + # Ops-based layers don't support is_first_microbatch parameter + is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"] + + with recipe_ctx(): + if is_ops_layer: 
+ out = layer(inp, **m_splits) + else: + out = layer(inp, is_first_microbatch=True, **m_splits) + out.sum().backward() + + del inp + init_cuda_memory = Utils.get_cuda_memory_mb() + + # run layer without offload + inp = Utils.create_tensor(None) + with recipe_ctx(): + if is_ops_layer: + out = layer(inp, **m_splits) + else: + out = layer(inp, is_first_microbatch=False, **m_splits) + with recipe_ctx(): + out = out + 1 + del inp + cuda_memory_no_offload = Utils.get_cuda_memory_mb() + + out.sum().backward() + # run layer with offload + inp = Utils.create_tensor(None) + with offload_ctx, recipe_ctx(): + if is_ops_layer: + out = layer(inp, **m_splits) + else: + out = layer(inp, is_first_microbatch=False, **m_splits) + out = sync_function(out) + with offload_ctx, recipe_ctx(): + out = out + 1 + out = sync_function(out) + del inp + assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) + offloaded_memory_cpu = offload_ctx.offload_synchronizer.get_offloaded_total_size_mb() + + # This assertion verifies that the memory used by tensors on the CPU matches the memory saved from a layer. + # It helps catch cases where an offloaded tensor still has a live pointer, which would + # cause an unnecessary copy to the CPU and prevent GPU memory from being released. 
+ assert Utils.get_cuda_memory_mb() + offloaded_memory_cpu == pytest.approx( + cuda_memory_no_offload, 0.1 + ) + out.sum().backward() + + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("recipe", quantization_recipes) + def test_manual_synchronization(self, recipe, layer_type): + Utils.memory_leak_check() + + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() + ): + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + + offload_ctx, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, + model_layers=6, + offload_activations=True, + manual_synchronization=True, + ) + layer_1 = Utils.create_layer(layer_type) + layer_2 = Utils.create_layer(layer_type) + inp1 = Utils.create_tensor(None) + inp2 = Utils.create_tensor(None) + + recipe_ctx = Utils.create_recipe_ctx(recipe) - raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + + init_cuda_memory = Utils.get_cuda_memory_mb() + + # 1 fwd + with offload_ctx, recipe_ctx(): + out_1 = layer_1(inp1, **m_splits) + out_1 = sync_function(out_1) + + with offload_ctx, recipe_ctx(): + out_2 = layer_2(inp2, **m_splits) + out_2 = sync_function(out_2) + + mark_not_offload(out_1, out_2) + + del inp1, inp2 + + memory_before_offload = Utils.get_cuda_memory_mb() + manual_controller.start_offload_layer(0) + manual_controller.release_activation_forward_gpu_memory(0) + manual_controller.start_offload_layer(1) + manual_controller.release_activation_forward_gpu_memory(1) + memory_after_offload = Utils.get_cuda_memory_mb() + assert memory_after_offload + EPSILON < memory_before_offload + + manual_controller.start_reload_layer(0) + manual_controller.start_reload_layer(1) + + memory_after_reload 
= Utils.get_cuda_memory_mb() + assert memory_after_reload == pytest.approx(memory_before_offload, 0.1) + + out_1.sum().backward() + out_2.sum().backward() + + @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) + @pytest.mark.parametrize("use_cuda_graphs", [True, False]) + @pytest.mark.parametrize("retain_pinned_cpu_buffers", [True, False]) + @pytest.mark.parametrize("backend", ["FlashAttention", "FusedAttention", "UnfusedAttention"]) + def test_numerics( + self, + recipe, + layer_type, + use_cuda_graphs, + backend, + retain_pinned_cpu_buffers, + ): + # Skip ops-based layers with Float8BlockScaling recipe + if ( + layer_type in ["linear_op", "layernorm_mlp_ops"] + and recipe is not None + and recipe.float8_block_scaling() + ): + pytest.skip("Fusible operations do not support FP8 block scaling recipe") + recipe_ctx = Utils.create_recipe_ctx(recipe) -def _measure_cached_memory( - modules: Iterable[torch.nn.Module], - quantization_recipe: Optional[recipe.Recipe], - cpu_offload: bool, -) -> float: - """Measure the growth in allocated GPU memory in MiB after a model forward pass. + if use_cuda_graphs and not retain_pinned_cpu_buffers: + pytest.skip( + "Cuda graphs are not yet supported with cpu offloading when" + " retain_pinned_cpu_buffers is False." + ) - Memory measurement excludes the input and output tensors. 
+ if backend == "FusedAttention" and use_cuda_graphs: + pytest.skip( + "Fused attention + cuda graphs is temporarily broken, not because of cpu offloading" + ) - """ + if (IS_HIP_EXTENSION + and backend == "FusedAttention" + and not use_cuda_graphs + and layer_type in ("multihead_attention", "transformer_layer") + ): + pytest.skip("No dot product attention backend is available for the provided inputs") - # Reset memory - gc.collect() - torch.cuda.empty_cache() + os.environ["NVTE_FLASH_ATTN"] = "0" + os.environ["NVTE_FUSED_ATTN"] = "0" + os.environ["NVTE_UNFUSED_ATTN"] = "0" - # Context and sync function for CPU offloading - if cpu_offload: - offload_context, sync_function = te.get_cpu_offload_context( + if backend == "FlashAttention": + os.environ["NVTE_FLASH_ATTN"] = "1" + elif backend == "FusedAttention": + os.environ["NVTE_FUSED_ATTN"] = "1" + elif backend == "UnfusedAttention": + os.environ["NVTE_UNFUSED_ATTN"] = "1" + + offload_ctx, sync_function = get_cpu_offload_context( enabled=True, - num_layers=len(modules), - model_layers=len(modules) + 1, + num_layers=1, + model_layers=2, offload_activations=True, offload_weights=False, + retain_pinned_cpu_buffers=retain_pinned_cpu_buffers, ) - else: - offload_context = contextlib.nullcontext() - sync_function = lambda x: x - - # Forward pass, with dummy step to trigger offload for last module - inp = _make_input() - tensor = inp - memory_before_forward = torch.cuda.memory_allocated() / (1024**2) - for module in modules: - with te.fp8_autocast( - enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe - ), offload_context: - tensor = module(tensor) - tensor = sync_function(tensor) - with offload_context: - tensor = tensor.clone() - tensor = sync_function(tensor) - memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) - - # Backward pass - tensor.sum().backward() - torch.cuda.synchronize() - - # Memory usage in MiB - return memory_after_forward - memory_before_forward - - 
-@pytest.mark.parametrize("quantization_recipe", quantization_recipes) -@pytest.mark.parametrize("model_name", model_types.keys()) -def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: - """Check that CPU offloading runs and has expected memory usage.""" - - # Construct model - modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] - if model_name in ["multihead_attention", "transformer_layer"]: - available_backends, *_ = get_available_attention_backends( - model_config["small"], - qkv_dtype=torch.bfloat16, - qkv_layout="sbhd_sbhd_sbhd", + + class Callable(torch.nn.Module): + def __init__(self, offload_ctx=None, sync_function=None): + super().__init__() + self.layers = torch.nn.ModuleList( + [Utils.create_layer(layer_type) for _ in range(2)] + ) + self.offload_ctx = offload_ctx + self.sync_function = sync_function + + def forward(self, x): + m_splits = ( + {"m_splits": [Utils._B * Utils._S // Utils._H] * Utils._H} + if layer_type == "grouped_linear" + else {} + ) + is_ops_layer = layer_type in ["linear_op", "layernorm_mlp_ops"] + for layer in self.layers: + with self.offload_ctx, recipe_ctx(): + if is_ops_layer: + x = layer(x, **m_splits) + else: + x = layer(x, is_first_microbatch=False, **m_splits) + if self.sync_function is not None: + x = self.sync_function(x) + return x + + callable_offload = Callable(offload_ctx=offload_ctx, sync_function=sync_function) + callable_no_offload = Callable(offload_ctx=contextlib.nullcontext(), sync_function=None) + + # copy parameters + for param_offload, param_no_offload in zip( + callable_offload.parameters(), callable_no_offload.parameters() + ): + param_offload.data.copy_(param_no_offload.data) + + x = Utils.create_tensor(None) + + if use_cuda_graphs: + callable_offload = te.make_graphed_callables( + callable_offload, + (x,), + enabled=recipe is not None, + recipe=(Utils.create_recipe_ctx(recipe) if recipe is not None else None), + ) + + # warm up (for example to compute 
sf for delayed scaling) + for _ in range(4): + out = callable_offload(x) + out.sum().backward() + out = callable_no_offload(x) + out.sum().backward() + + callable_offload.zero_grad(set_to_none=True) + out_offload = callable_offload(x) + out_offload.sum().backward() + + # save out and gradients + offload_outs = [out_offload] + for param in callable_offload.parameters(): + offload_outs.append(param.detach().clone()) + + torch.cuda.reset_peak_memory_stats() + out_no_offload = callable_no_offload(x) + out_no_offload.sum().backward() + + # collect gradients + no_offload_outs = [out_no_offload] + for param in callable_no_offload.parameters(): + no_offload_outs.append(param.detach().clone()) + + # check if tensors are the same + for i in range(len(offload_outs)): + assert torch.allclose(offload_outs[i], no_offload_outs[i]), f"Error in tensor {i}." + + torch.cuda.synchronize() + + def test_example_from_doc(self): + offload_stream = torch.cuda.Stream() + num_layers = 10 + layers = [Utils.create_layer("transformer_layer") for _ in range(num_layers)] + inp = [Utils.create_tensor(None) for _ in range(num_layers)] + out = [None] * num_layers + cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, + model_layers=num_layers, + manual_synchronization=True, + offload_stream=offload_stream, ) - _, fused_attn_supported, _ = available_backends - if not fused_attn_supported: - pytest.skip("Fused attention backend not available.") - os.environ["NVTE_FLASH_ATTN"] = "0" - _attention_backends["backend_selection_requires_update"] = True - - # Warmup - _warmup_model(modules_list, quantization_recipe) - - # Measure cached memory after forward pass - memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) - memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) - - # Check for expected memory usage - assert memory_with_offload < memory_without_offload - memory_from_cached_weights = 
_estimate_cached_weight_size( - model_name, - modules_list, - quantization_recipe, - ) - assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON + + for i in range(num_layers): + with cpu_offload_context: + out[i] = layers[i].forward(inp[i]) + out[i] = sync_function(out[i]) + manual_controller.start_offload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + manual_controller.release_activation_forward_gpu_memory(i) + + for i in range(num_layers - 1, -1, -1): + # these calls are intended to be done in the backward pass + manual_controller.start_reload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + out[i].sum().backward() diff --git a/tests/pytorch/test_cpu_offloading_v1.py b/tests/pytorch/test_cpu_offloading_v1.py new file mode 100644 index 000000000..07091ee7a --- /dev/null +++ b/tests/pytorch/test_cpu_offloading_v1.py @@ -0,0 +1,217 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import contextlib +import gc +import os +from typing import Iterable, Optional + +import pytest +import torch + +import transformer_engine.pytorch as te +from transformer_engine.common import recipe +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported +from utils import ModelConfig, get_available_attention_backends + +# Check supported quantization schemes +fp8_available = te.is_fp8_available() +mxfp8_available = te.is_mxfp8_available() + +quantization_recipes: Optional[recipe.Recipe] = [None] +if fp8_available: + quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) + +model_config = { + "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), +} +SIZE = model_config["small"].hidden_size +NUM_HEADS = model_config["small"].num_heads +NUM_LAYERS = model_config["small"].num_layers +EPSILON = model_config["small"].eps + +# Flash attention saves some internal tensor for the backward pass +# that cannot be offloaded to CPU. +assert os.getenv("NVTE_FLASH_ATTN", "1") == "0" + +# CPU offload v1 code path is enabled +assert os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" + +# Offloading is supported for attention only for fused and flash attention backends, +# so the use of bfloat16 is required. +# +# For the TransformerLayer, activation offloading with dropout is not supported, +# so we set hidden_dropout to 0.0. 
+model_types = { + "linear": lambda: te.Linear(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_mlp": lambda: te.LayerNormMLP(SIZE, SIZE, params_dtype=torch.bfloat16), + "layernorm_linear": lambda: te.LayerNormLinear(SIZE, SIZE, params_dtype=torch.bfloat16), + "multihead_attention": lambda: te.MultiheadAttention( + SIZE, NUM_HEADS, params_dtype=torch.bfloat16 + ), + "transformer_layer": lambda: te.TransformerLayer( + SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 + ), + "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + "layernorm_mlp_ops": lambda: te.ops.Sequential( + te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + ), +} + + +def _make_input() -> torch.Tensor: + """Generate random input tensor.""" + return torch.randn( + (128, SIZE, SIZE), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) + + +def _warmup_model( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> None: + """Perform forward and backward pass""" + tensor = _make_input() + for module in modules: + with te.autocast( + enabled=quantization_recipe is not None, + recipe=quantization_recipe, + ): + tensor = module(tensor) + tensor.sum().backward() + + +def _estimate_cached_weight_size( + model_name: str, + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> float: + """Calculate the memory (in MiB) needed for weight caching.""" + + # The weight params are cached directly for unquantized compute + if quantization_recipe is None: + return 0 + + # Count number of weight param elements + param_elements = 0 + for module in modules: + for param in module.parameters(): + if param.dim() == 2: + param_elements += param.numel() + + # FP8 tensor-scaling caches one byte per element + if quantization_recipe.delayed() or 
quantization_recipe.float8_current_scaling(): + if not is_non_tn_fp8_gemm_supported() and model_name not in ( + "linear_op", + "layernorm_mlp_ops", + ): + # Modules do not deallocate FP8 transpose for weights + return 2 * param_elements / 1024**2 + return param_elements / 1024**2 + + # MXFP8 caches one data byte per element and one scale byte per 32 + # elements + if quantization_recipe.mxfp8(): + if model_name not in ("linear_op", "layernorm_mlp_ops"): + # Modules do not deallocate column-wise MXFP8 data for weights + return 2 * param_elements * (1 + 1 / 32) / 1024**2 + return param_elements * (1 + 1 / 32) / 1024**2 + + raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + + +def _measure_cached_memory( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], + cpu_offload: bool, +) -> float: + """Measure the growth in allocated GPU memory in MiB after a model forward pass. + + Memory measurement excludes the input and output tensors. + + """ + + # Reset memory + gc.collect() + torch.cuda.empty_cache() + + # Context and sync function for CPU offloading + if cpu_offload: + offload_context, sync_function = te.get_cpu_offload_context( + enabled=True, + num_layers=len(modules), + model_layers=len(modules) + 1, + offload_activations=True, + offload_weights=False, + ) + else: + offload_context = contextlib.nullcontext() + sync_function = lambda x: x + + # Forward pass, with dummy step to trigger offload for last module + inp = _make_input() + tensor = inp + memory_before_forward = torch.cuda.memory_allocated() / (1024**2) + for module in modules: + with te.autocast( + enabled=quantization_recipe is not None, recipe=quantization_recipe + ), offload_context: + tensor = module(tensor) + tensor = sync_function(tensor) + with offload_context: + tensor = tensor.clone() + tensor = sync_function(tensor) + memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) + + # Backward pass + 
tensor.sum().backward() + torch.cuda.synchronize() + + # Memory usage in MiB + return memory_after_forward - memory_before_forward + + +@pytest.mark.parametrize("quantization_recipe", quantization_recipes) +@pytest.mark.parametrize("model_name", model_types.keys()) +def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: + """Check that CPU offloading runs and has expected memory usage.""" + + # Construct model + modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] + if model_name in ["multihead_attention", "transformer_layer"]: + available_backends, *_ = get_available_attention_backends( + model_config["small"], + qkv_dtype=torch.bfloat16, + qkv_layout="sbhd_sbhd_sbhd", + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("Fused attention backend not available.") + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + + # Warmup + _warmup_model(modules_list, quantization_recipe) + + # Measure cached memory after forward pass + memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) + memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) + + # Check for expected memory usage + assert memory_with_offload < memory_without_offload + memory_from_cached_weights = _estimate_cached_weight_size( + model_name, + modules_list, + quantization_recipe, + ) + assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 90e624c94..eacbf5168 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. 
-from typing import Iterable, List, Union +from typing import Callable, Dict, Iterable, List, Tuple, Union import pytest import torch @@ -13,31 +13,81 @@ Linear, MultiheadAttention, TransformerLayer, - fp8_autocast, - fp8_model_init, + autocast, + quantized_model_init, make_graphed_callables, + is_fp8_available, + is_fp8_block_scaling_available, + is_mxfp8_available, + is_bf16_available, ) -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe from utils import ModelConfig, reset_rng_states # Check if FP8 is supported. -fp8_available, _ = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() -mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +fp8_available = is_fp8_available() +fp8_block_scaling_available = is_fp8_block_scaling_available() +mxfp8_available = is_mxfp8_available() # Reset RNG states. 
reset_rng_states() model_configs = { - "small": ModelConfig(32, 2, 2, 32), + "small": ModelConfig(2, 32, 2, 32), } + +def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + # if using RHT, we can only support bf16 + # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + supported_input_dtypes.append(torch.bfloat16) + # if not using RHT, we can add fp32 as well + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -46,7 +96,7 @@ # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # 
bf16 requires sm_80 or higher dtypes.append(torch.bfloat16) @@ -110,6 +160,20 @@ def get_outputs( return values +def reset_graphs( + graphed_callables: Union[Callable, Tuple[Callable, ...], Dict[Tuple[int, int], Callable]], +) -> None: + """Reset CUDA graphs.""" + if isinstance(graphed_callables, tuple) or isinstance(graphed_callables, list): + for callable in graphed_callables: + callable.reset() + elif isinstance(graphed_callables, dict): + for callable in graphed_callables.values(): + callable.reset() + else: + graphed_callables.reset() + + class _Sequential(torch.nn.Sequential): """Sequential model that forwards keyword arguments to modules""" @@ -154,7 +218,7 @@ def _test_cuda_graphs( fp8_weight_caching = False # Create modules. - with fp8_model_init(enabled=fp8_params, recipe=fp8_recipe): + with quantized_model_init(enabled=fp8_params, recipe=fp8_recipe): if module == "transformer": modules = [ TransformerLayer( @@ -234,9 +298,9 @@ def _test_cuda_graphs( model, (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, - fp8_enabled=fp8, - fp8_weight_caching=fp8_weight_caching, - fp8_recipe=fp8_recipe, + enabled=fp8, + cache_quantized_params=fp8_weight_caching, + recipe=fp8_recipe, ) elif graph_mode == "individual": # Graph individual modules. 
@@ -245,9 +309,9 @@ def _test_cuda_graphs( module, (generate_data(model_config, dtype, warmup=True),), num_warmup_iters=10, - fp8_enabled=fp8, - fp8_weight_caching=fp8_weight_caching, - fp8_recipe=fp8_recipe, + enabled=fp8, + cache_quantized_params=fp8_weight_caching, + recipe=fp8_recipe, ) for module in modules ] @@ -264,7 +328,7 @@ def _test_cuda_graphs( for grad_accumulation_step in range(2): input_ = generate_data(model_config, dtype) grad_output = generate_data(model_config, dtype, requires_grad=False) - with fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=fp8, recipe=fp8_recipe): kwargs = {} if fp8_weight_caching: kwargs["is_first_microbatch"] = grad_accumulation_step == 0 @@ -272,13 +336,18 @@ def _test_cuda_graphs( output.backward(grad_output) optimizer.step() - return get_outputs(model, output) + outputs = get_outputs(model, output) + if graph_mode == "full": + reset_graphs(model) + elif graph_mode == "individual": + reset_graphs(modules) + return outputs @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) def test_make_graphed_callables( *, module: str, @@ -295,8 +364,18 @@ def test_make_graphed_callables( pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op": - pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") + if fp8 and (fp8_recipe.float8_block_scaling() or fp8_recipe.nvfp4()) and module == "linear_op": + pytest.skip( + f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" + ) + if fp8 and fp8_recipe.nvfp4(): + if dtype not in 
get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe" + f" {fp8_recipe.__class__.__name__}" + ) + if fp8_params: + pytest.skip("NVFP4 params not supported") # Run model with different CUDA graph settings. model_config = model_configs[model_config] @@ -334,17 +413,19 @@ def test_make_graphed_callables( "module", _test_make_graphed_callables_with_fp8_weight_caching_modules, ) +@pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) def test_make_graphed_callables_with_fp8_weight_caching( *, module: str, + dtype: torch.dtype, fp8_params: bool, fp8_recipe: recipe.Recipe, ) -> None: test_make_graphed_callables( module=module, - dtype=torch.float32, + dtype=dtype, fp8_params=fp8_params, fp8_recipe=fp8_recipe, fp8_weight_caching=True, @@ -396,7 +477,7 @@ def _test_cuda_graphs_with_dot_product_attention( model, generate_data_for_dot_product_attention(model_config, dtype, warmup=True), num_warmup_iters=10, - fp8_enabled=False, + enabled=False, ) # Forward and backward passes. 
@@ -406,7 +487,10 @@ def _test_cuda_graphs_with_dot_product_attention( output = model(*inputs) output.backward(grad_output) - return get_outputs(model, output) + outputs = get_outputs(model, output) + if with_graph: + reset_graphs(model) + return outputs @pytest.mark.parametrize("dtype", dtypes) @@ -491,7 +575,10 @@ def _test_cuda_graphs_with_kwargs( output.backward(grad_output) optimizer.step() - return get_outputs(model, output) + outputs = get_outputs(model, output) + if with_graph: + reset_graphs(model) + return outputs def test_make_graphed_callables_with_kwargs( @@ -606,7 +693,10 @@ def backward(layer_idx: int, microbatch_idx: int): optimizer.step() outputs = [y for _, y in sorted(outputs.items())] - return get_outputs(model, outputs) + outputs = get_outputs(model, outputs) + if with_graph: + reset_graphs(layer_forwards) + return outputs def test_make_graphed_callables_with_interleaved_pipeline_parallelism( diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py new file mode 100644 index 000000000..64f1c3d15 --- /dev/null +++ b/tests/pytorch/test_custom_recipe.py @@ -0,0 +1,332 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.common import recipe +from transformer_engine.pytorch import ( + autocast, + Linear, + LayerNormLinear, + LayerNormMLP, + GroupedLinear, + Float8CurrentScalingQuantizer, +) +import transformer_engine.pytorch.ops as te_ops +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( + nvfp4_ref_rht_2d_quantizer_factory, +) + + +@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear"]) +def test_custom_recipe_sanity_modules_nvfp4(module_type): + """Test modules with NVFP4 custom recipe support""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + # Simple linear layer with dims divisible by 16 + in_features = 64 + out_features = 64 + batch = 32 + + if module_type == "Linear": + model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + elif module_type == "LayerNormLinear": + model = LayerNormLinear( + in_features, out_features, params_dtype=torch.bfloat16, bias=False + ).cuda() + else: # OpsLinear + model = te_ops.Linear( + in_features, out_features, device="cuda", dtype=torch.bfloat16, bias=False + ) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Use NVFP4 quantizer factory + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) + + # Execute with custom recipe + with autocast(enabled=True, recipe=custom_recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Basic sanity: gradients exist + assert inp.grad is not None + + +@pytest.mark.parametrize("module_type", ["Linear", "LayerNormLinear", "OpsLinear", "LayerNormMLP"]) +def test_custom_recipe_sanity(module_type): + available, reason = 
te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + # Simple linear layer with dims divisible by 16 + in_features = 64 + out_features = 64 + batch = 32 + + if module_type == "Linear": + model = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + elif module_type == "LayerNormLinear": + model = LayerNormLinear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + elif module_type == "LayerNormMLP": + # hidden_size == in_features == out_features for simplicity + model = LayerNormMLP( + hidden_size=in_features, ffn_hidden_size=out_features, params_dtype=torch.bfloat16 + ).cuda() + else: + # OpsLinear path + model = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Single factory: map roles to quantizers + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + # Execute with custom recipe + with autocast(enabled=True, recipe=custom_recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Basic sanity: gradients exist + assert inp.grad is not None + + +def test_custom_recipe_grouped_linear_sanity(): + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + num_gemms = 3 + in_features = 64 + out_features = 64 + batch = 32 + base 
= batch // num_gemms + rem = batch % num_gemms + m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)] + + model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda() + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + with autocast(enabled=True, recipe=custom_recipe): + out = model(inp, m_splits) + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None + + +def test_custom_recipe_matches_current_scaling(): + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(123) + + in_features = 64 + out_features = 64 + batch = 32 + + # Create two identical models + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + model_custom = Linear(in_features, out_features, params_dtype=torch.bfloat16).cuda() + model_custom.load_state_dict(model_ref.state_dict()) + + # Identical inputs for both paths + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_custom = base_inp.clone().detach().requires_grad_(True) + + # Reference: use Float8CurrentScaling recipe + ref_recipe = recipe.Float8CurrentScaling() + with autocast(enabled=True, recipe=ref_recipe): + out_ref = model_ref(inp_ref) + # Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd) + 
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3 + assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3 + assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3 + assert ref_bwd_go.dtype == tex.DType.kFloat8E5M2 + assert ref_bwd_gi.dtype == tex.DType.kFloat8E5M2 + + # Stress dynamic range in grad_output + scale = torch.ones(out_features, device="cuda", dtype=torch.float32) + scale[0] = 1e8 + scale[1] = 1e-8 + loss_ref = (out_ref.float() * scale.view(1, -1)).sum() + loss_ref.backward() + + # Custom: single factory returning quantizers per role to match Float8CurrentScaling + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=quantizer_factory) + + with autocast(enabled=True, recipe=custom_recipe): + out_custom = model_custom(inp_custom) + # Assert dtypes for custom quantizers match reference mapping + cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] + cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT] + cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1] + cus_bwd_gi = 
model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1] + assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3 + assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 + assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 + assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + + loss_custom = (out_custom.float() * scale.view(1, -1)).sum() + loss_custom.backward() + + # Compare forward outputs (exact match expected) + assert torch.allclose(out_ref, out_custom, rtol=0.0, atol=0.0) + + # Compare input gradients + assert inp_ref.grad is not None and inp_custom.grad is not None + assert torch.allclose(inp_ref.grad, inp_custom.grad, rtol=0.0, atol=0.0) + + # Compare parameter gradients (weights and bias if present) + ref_params = dict(model_ref.named_parameters()) + custom_params = dict(model_custom.named_parameters()) + for name, p_ref in ref_params.items(): + p_cus = custom_params[name] + assert p_ref.grad is not None and p_cus.grad is not None + assert torch.allclose(p_ref.grad, p_cus.grad, rtol=0.0, atol=0.0) + + +def test_custom_recipe_ops_linear_2_1_layout(): + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(7) + + in_features = 64 + out_features = 64 + batch = 16 + + # Use ops.Linear which consumes 2 forward quantizers and 1 backward quantizer + op = te_ops.Linear(in_features, out_features, device="cuda", dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + def quantizer_factory(role): + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") + return 
Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom = recipe.CustomRecipe(qfactory=quantizer_factory) + + with autocast(enabled=True, recipe=custom): + out = op(inp) + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None + + +def test_custom_recipe_factory_invocation_counts_and_cycling(): + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(13) + + in_features = 64 + out_features = 64 + batch = 8 + + op = Linear(in_features, out_features, params_dtype=torch.bfloat16) + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + # Counters per role + counts = { + "linear_input": 0, + "linear_weight": 0, + "linear_output": 0, + "linear_grad_output": 0, + "linear_grad_input": 0, + } + + def quantizer_factory(role): + if role in counts: + counts[role] += 1 + if role in ("linear_input", "linear_weight", "linear_output"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + if role in ("linear_grad_output", "linear_grad_input"): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + + custom = recipe.CustomRecipe(qfactory=quantizer_factory) + + # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), + # and backward to build 2 quantizers (cycled from 1 factory). 
+ with autocast(enabled=True, recipe=custom): + out = op(inp) + loss = out.float().sum() + loss.backward() + + # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input + assert counts["linear_input"] == 1 + assert counts["linear_weight"] == 1 + assert counts["linear_output"] == 1 + assert counts["linear_grad_output"] == 1 + assert counts["linear_grad_input"] == 1 + + +def test_factories_return_distinct_instances_and_buffers(): + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + # Two calls should produce distinct quantizer objects and distinct tensor buffers + def factory(): + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + + q1 = factory() + q2 = factory() + + assert q1 is not q2 + assert q1.scale.data_ptr() != q2.scale.data_ptr() + assert q1.amax.data_ptr() != q2.amax.data_ptr() + + # Mutating one should not affect the other + q1.scale.fill_(123.0) + assert not torch.equal(q1.scale, q2.scale) diff --git a/tests/pytorch/test_deferred_init.py b/tests/pytorch/test_deferred_init.py index 7d6d52362..4ce522495 100644 --- a/tests/pytorch/test_deferred_init.py +++ b/tests/pytorch/test_deferred_init.py @@ -4,7 +4,6 @@ import pytest import torch -import torch.distributed as dist import transformer_engine.pytorch as te diff --git a/tests/pytorch/test_float8_blockwise_gemm_exact.py b/tests/pytorch/test_float8_blockwise_gemm_exact.py index ec23cfe8c..9ae8a6069 100644 --- a/tests/pytorch/test_float8_blockwise_gemm_exact.py +++ b/tests/pytorch/test_float8_blockwise_gemm_exact.py @@ -4,22 +4,22 @@ import pytest import torch -import transformer_engine as te +import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from 
transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( +from transformer_engine.pytorch import ( Float8BlockQuantizer, - Float8BlockwiseQTensor, + get_device_compute_capability, ) from references.blockwise_quantizer_reference import CuBLASScaleMunger from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm def fp8_blockwise_gemm_supported() -> bool: - supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() - return supported + supported = te.is_fp8_block_scaling_available() + emulated = get_device_compute_capability() >= (10, 0) + return supported and not emulated def cublas_gemm_fp8_blockwise_case( diff --git a/tests/pytorch/test_float8_blockwise_scaling_exact.py b/tests/pytorch/test_float8_blockwise_scaling_exact.py index 858ce73b6..153f0b7e0 100644 --- a/tests/pytorch/test_float8_blockwise_scaling_exact.py +++ b/tests/pytorch/test_float8_blockwise_scaling_exact.py @@ -8,14 +8,12 @@ import pathlib import pytest import torch -import transformer_engine as te -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +import transformer_engine.pytorch as te from transformer_engine.common.recipe import Float8BlockScaling from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( +from transformer_engine.pytorch import ( Float8BlockQuantizer, - Float8BlockwiseQTensor, + get_device_compute_capability, ) from references.blockwise_quantizer_reference import ( BlockwiseQuantizerReference, @@ -31,7 +29,8 @@ tensor_dump_dir_env = os.getenv("NVTE_TEST_BLOCK_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR") if tensor_dump_dir_env is not None: TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env) -recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available() +recipe_available, reason_for_no_recipe = te.is_fp8_block_scaling_available(return_reason=True) +recipe_emulated = 
get_device_compute_capability() >= (10, 0) class GetRecipes: @@ -218,6 +217,12 @@ def check_quantization_block_tiling_versus_reference( pow_2_scales: bool, tile_size: Tuple[int, int], ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + te_dtype = TE_DType[quant_dtype] if tile_size == (1, 128): block_scaling_dim = 1 @@ -409,6 +414,12 @@ def test_quantization_block_tiling_extrema_versus_reference( tile_size: Tuple[int, int], extrema_high: bool, ) -> None: + if recipe_emulated and not pow_2_scales: + pytest.skip( + "On Blackwell and newer, the FP8 block scaling recipe is emulated " + "with MXFP8, which requires using power of two scaling factors." + ) + # This test runs a single tile through a quantizer as a way to test # branch coverage of scale computation. te_dtype = TE_DType[quant_dtype] diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 281fc67a5..21fe4700b 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -8,12 +10,9 @@ import pytest import transformer_engine.pytorch as te -import transformer_engine_torch as tex -import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.common.recipe import Float8CurrentScaling, Format -from transformer_engine.pytorch.fp8 import fp8_autocast, get_fp8_torch_dtype +from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype # read env variable NVTE_TEST_FLOAT8_CURRENT_SCALING_EXACT_TENSOR_DUMP_DIR to override the default tensor dump directory @@ -24,7 +23,7 @@ # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) class GetRecipes: @@ -273,6 +272,14 @@ def run_linear_multiple_steps( if bgrad_list is not None and bgrad is not None: bgrad_list.append(bgrad.detach().clone()) + # Stack the results + return ( + torch.stack(y_q_list), + torch.stack(dgrad_list), + torch.stack(wgrad_list), + torch.stack(bgrad_list) if bgrad_list is not None else None, + ) + @classmethod def run_linear( cls, @@ -387,7 +394,7 @@ def compare_recipe( # recipe1 using_fp8_recipe = recipe1() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + with autocast(enabled=True, recipe=recipe1()): y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) else: y_q_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_linear(x, w, bias, gradient) @@ -395,7 +402,7 @@ def compare_recipe( # recipe2 using_fp8_recipe = recipe2() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + with autocast(enabled=True, recipe=recipe2()): y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) else: y_q, dgrad, wgrad, bgrad = self.run_linear(x, w, bias, gradient) @@ -610,7 +617,7 @@ def compare_recipe( # recipe1 using_fp8_recipe = recipe1() != 
GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe1()): + with autocast(enabled=True, recipe=recipe1()): y_q_ref, ln_out_ref, dgrad_ref, wgrad_ref, bgrad_ref = self.run_layernorm_linear( x, w, @@ -632,7 +639,7 @@ def compare_recipe( # recipe2 using_fp8_recipe = recipe2() != GetRecipes.none() if using_fp8_recipe: - with fp8_autocast(enabled=True, fp8_recipe=recipe2()): + with autocast(enabled=True, recipe=recipe2()): y_q, ln_out, dgrad, wgrad, bgrad = self.run_layernorm_linear( x, w, @@ -842,7 +849,7 @@ def test_fp8_current_scaling_linear_large_numel_e4m3(self, dtype, shape): pytest.skip(f"Skipping {shape}: insufficient device memory for allocation.") try: - with fp8_autocast(enabled=True, fp8_recipe=recipe): + with autocast(enabled=True, recipe=recipe): y = layer(x) except torch.OutOfMemoryError: pytest.skip(f"Skipping {shape}: OOM during forward.") diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index c5bc7180a..b3346107a 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -14,12 +14,11 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine.common.recipe -import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( +from transformer_engine.pytorch import ( Float8BlockQuantizer, Float8BlockwiseQTensor, + get_device_compute_capability, ) -from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex # PyTorch tensor dtypes diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index a5c97a950..b7ddf0e8a 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -11,13 +11,11 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager 
-from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer, ) -from transformer_engine.pytorch.constants import TE_DType, TE_DType_To_Torch from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported import transformer_engine_torch as tex @@ -47,7 +45,7 @@ def _to_list(x: Union[Iterable, Any]) -> List: DimsType = Union[Iterable[int], int] # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) # delayed scaling diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 0bd11d941..5526103d5 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -14,15 +14,12 @@ from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling -from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention -from transformer_engine.pytorch import fp8_model_init -from transformer_engine.pytorch.utils import is_bf16_compatible -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch import MultiheadAttention, quantized_model_init, is_bf16_available from transformer_engine.pytorch.utils import gpu_autocast_ctx from transformer_engine.pytorch.utils import get_device_compute_capability # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) class TestFusedOptimizer: @@ -192,7 +189,7 @@ def gen_precision_aware_test( build_model_context = nullcontext build_model_context_args = {} if use_fp8_params: - build_model_context = fp8_model_init + build_model_context = quantized_model_init build_model_context_args["enabled"] = True with build_model_context(**build_model_context_args): @@ -290,7 +287,7 @@ def test_fp32_no_master(self): exp_avg_sq_dtype=torch.float32, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp32_master(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -302,7 +299,7 @@ def test_fp32_master(self): exp_avg_sq_dtype=torch.float32, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp32_master_store_param_remainders(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -315,7 +312,7 @@ def test_fp32_master_store_param_remainders(self): store_param_remainders=True, ) - 
@pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp16_master(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -329,7 +326,7 @@ def test_fp16_master(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_grad(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -343,7 +340,7 @@ def test_bf16_grad(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp16_exp_avg(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -357,7 +354,7 @@ def test_fp16_exp_avg(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_exp_avg(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -371,7 +368,7 @@ def test_bf16_exp_avg(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): model_tol = 3e-2 if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) else None @@ -389,7 +386,7 @@ def test_fp8_exp_avg(self): model_atol=model_tol, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_fp16_exp_avg_sq(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -403,7 +400,7 @@ def test_fp16_exp_avg_sq(self): master_atol=2e-3, ) - @pytest.mark.skipif(not 
is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_exp_avg_sq(self): self.gen_precision_aware_test( use_fp8_params=False, @@ -417,7 +414,7 @@ def test_bf16_exp_avg_sq(self): master_atol=2e-3, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg_sq(self): self.gen_precision_aware_test( @@ -431,7 +428,7 @@ def test_fp8_exp_avg_sq(self): skip_assert=True, ) - @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") + @pytest.mark.skipif(not is_bf16_available(), reason="bf16 if not supported") def test_bf16_model_weight_cast(self): dtype = torch.bfloat16 model = MultiheadAttention( @@ -475,7 +472,7 @@ def test_bf16_model_weight_cast(self): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_model_weight_cast(self): dtype = torch.bfloat16 - with fp8_model_init(enabled=True, recipe=DelayedScaling()): + with quantized_model_init(enabled=True, recipe=DelayedScaling()): model = MultiheadAttention( hidden_size=1024, num_attention_heads=16, diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index 62d80b552..9e4ddbdad 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -58,10 +58,6 @@ def test_fused_rope( # are with the maximum length of the rope embeddings. pytest.skip("Skipping test with margin=0 and start_positions=True") - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. 
- pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 t = torch.rand( @@ -102,11 +98,8 @@ def test_fused_rope( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -121,17 +114,12 @@ def test_fused_rope( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) - - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - + torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() @@ -156,10 +144,6 @@ def test_fused_rope_thd( margin: int, ) -> None: - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. 
- pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] @@ -214,10 +198,8 @@ def test_fused_rope_thd( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -233,18 +215,142 @@ def test_fused_rope_thd( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) + assert output_fused.is_contiguous() - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - assert output_fused.is_contiguous() +@pytest.mark.parametrize("start_positions", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("rotary_percent", [1.0]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [2]) +@pytest.mark.parametrize("interleaved", [False, True]) +def test_unfused_rope_thd_vs_bshd( + dtype: torch.dtype, + hidden_size: int, + rotary_percent: float, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + """ + This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD + formats are the same. 
+ """ + device = torch.device("cuda:0") + seqlen, max_seqlen = 16, 2048 + batch_size, head_num = 4, 256 + + # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and + # that causes unexpected issues. + seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32) + + cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to( + device=device, dtype=torch.int32 + ) + + # Create a tensor in THD format + thd = torch.rand( + (cu_seqlens[-1] // cp_size, head_num, hidden_size), + dtype=dtype, + device=device, + ) + thd.requires_grad = True + + # Clone the tensor to create a tensor in BSHD format + bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach() + bshd = bshd.to(dtype=dtype, device=device) + bshd.requires_grad = True + + # Clone the tensor to create a tensor in SBHD format + sbhd = bshd.transpose(1, 0).clone().detach() + sbhd = sbhd.to(dtype=dtype, device=device) + sbhd.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb = rotary_pos_emb(max_seqlen) + assert emb.is_contiguous() + + start_positions = cu_seqlens[:-1] if start_positions else None + + for cp_rank in range(cp_size): + # unfused bshd + output_unfused_bshd = apply_rotary_pos_emb( + bshd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="bshd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + loss_unfused_bshd = loss_func(output_unfused_bshd) + loss_unfused_bshd.backward() + grad_unfused_bshd = bshd.grad.detach().clone() + bshd.grad = None + + # unfused sbhd + output_unfused_sbhd = apply_rotary_pos_emb( + sbhd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="sbhd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_sbhd = 
loss_func(output_unfused_sbhd) + loss_unfused_sbhd.backward() + grad_unfused_sbhd = sbhd.grad.detach().clone() + sbhd.grad = None + + # unfused thd + output_unfused_thd = apply_rotary_pos_emb( + thd.float(), + emb, + start_positions=start_positions, + tensor_format="thd", + interleaved=interleaved, + fused=False, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_thd = loss_func(output_unfused_thd) + loss_unfused_thd.backward() + grad_unfused_thd = thd.grad.detach().clone() + thd.grad = None + + torch.testing.assert_close( + output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd + ) + torch.testing.assert_close( + output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape), + output_unfused_thd, + ) + torch.testing.assert_close( + grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + torch.testing.assert_close( + grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + + assert output_unfused_thd.is_contiguous() + assert output_unfused_bshd.is_contiguous() + assert output_unfused_sbhd.is_contiguous() @pytest.mark.parametrize("start_positions", [True, False]) @@ -373,3 +479,19 @@ def test_fused_qkv_rope( if not isinstance(start_positions, torch.Tensor): torch.testing.assert_close(grad_fused, grad_unfused) + + +def test_rotary_position_embedding_forward_with_autocast_gives_same_result_as_without_autocast(): + rope_layer = RotaryPositionEmbedding(128) + + rope_embeddings_no_autocast = rope_layer(max_seq_len=1024) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + rope_embeddings_autocast = rope_layer(max_seq_len=1024) + + torch.testing.assert_close( + rope_embeddings_no_autocast.to(dtype=torch.bfloat16), + rope_embeddings_autocast.to(dtype=torch.bfloat16), + atol=1e-8, + rtol=1e-8, + ) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 500b25f58..a67fd4f45 100644 --- 
a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -9,8 +9,6 @@ from collections.abc import Iterable import io import math -import pathlib -import sys from typing import Optional import pytest @@ -19,7 +17,6 @@ import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, @@ -30,28 +27,29 @@ ForwardLinearBiasAdd, ForwardLinearScaleAdd, ) -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8Tensor, +from transformer_engine.pytorch import ( + QuantizedTensor, Float8CurrentScalingQuantizer, Float8Quantizer, + MXFP8Quantizer, + NVFP4Quantizer, + is_bf16_available, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer -from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex from torch.utils.cpp_extension import IS_HIP_EXTENSION # Import utility functions -from utils import dtype_tols, make_recipe, reset_rng_states +from utils import dtype_tols, make_recipe, quantization_tols, reset_rng_states -# Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +# Check for supported quantization schemes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) # Supported data types _dtypes: list[torch.dtype] = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires 
sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher _dtypes.append(torch.bfloat16) # Supported devices @@ -63,6 +61,8 @@ _quantization_list.extend(("fp8_delayed_scaling", "fp8_current_scaling")) if mxfp8_available: _quantization_list.append("mxfp8") +if nvfp4_available: + _quantization_list.append("nvfp4") def maybe_skip_quantization( @@ -70,6 +70,7 @@ def maybe_skip_quantization( *, dims: Optional[Iterable[int] | int] = None, device: Optional[torch.device | str] = None, + dtype: Optional[torch.dtype] = None, ) -> None: """Skip test case if a quantization scheme is not supported""" @@ -77,12 +78,17 @@ def maybe_skip_quantization( if quantization is None: return - # Check if quantization scheme is supported + # Check if quantization scheme is supported on device + if device is not None and torch.device(device).type != "cuda": + pytest.skip("Quantization is only supported on CUDA devices") if quantization in ("fp8", "fp8_delayed_scaling", "fp8_current_scaling") and not fp8_available: pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if quantization == "nvfp4" and not nvfp4_available: + pytest.skip(reason_for_no_nvfp4) + # Check dims if dims is not None: if not isinstance(dims, Iterable): dims = (dims,) @@ -92,10 +98,14 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") + elif quantization == "nvfp4": + if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: + pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") - # Check if device is supported - if device is not None and torch.device(device).type != "cuda": - pytest.skip("Quantization is only supported on CUDA devices") + # Check dtype + if dtype is not None: + if quantization == "nvfp4" and dtype != torch.bfloat16: + pytest.skip("NVFP4 quantization is only supported with BF16 data") 
@torch.no_grad() @@ -145,6 +155,14 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) + elif quantization == "nvfp4": + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -354,7 +372,7 @@ def test_fp8_scale_update( ) # Construct model - with te.fp8_model_init(recipe=recipe): + with te.quantized_model_init(recipe=recipe): model = te_ops.basic.BasicLinear( size, size, @@ -386,7 +404,7 @@ def test_fp8_scale_update( ) # Training step - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): y = model(x) y.backward(dy) with torch.no_grad(): @@ -399,12 +417,12 @@ def test_fp8_scale_update( torch.testing.assert_close( y, torch.full_like(y, y_val_ref), - **dtype_tols(tex.DType.kFloat8E4M3), + **quantization_tols("fp8_delayed_scaling"), ) torch.testing.assert_close( x.grad, torch.full_like(x.grad, dx_val_ref), - **dtype_tols(tex.DType.kFloat8E5M2), + **quantization_tols("fp8_delayed_scaling"), ) # Check that scaling factors match expected @@ -438,7 +456,8 @@ def test_dtype_cast( # Skip invalid configurations in_shape = (size, size) with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=init_dtype) + maybe_skip_quantization(quantization, dtype=final_dtype) # Random data dtype = torch.float32 @@ -454,7 +473,7 @@ def test_dtype_cast( ) # Construct operation - with te.fp8_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): + with te.quantized_model_init(enabled=with_quantization, recipe=make_recipe(quantization)): op = te_ops.Linear(size, size, 
bias=False, device=device, dtype=init_dtype) with torch.no_grad(): op.weight.copy_(w_test) @@ -506,11 +525,12 @@ def test_pyt_autocast( # Skip invalid configurations in_shape = (size, size) quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=model_dtype) + maybe_skip_quantization(quantization, dtype=autocast_dtype) # Construct operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weights, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weights, recipe=recipe): op = te_ops.Linear(size, size, bias=False, device=device, dtype=model_dtype) # Check forward and backward pass @@ -520,7 +540,7 @@ def test_pyt_autocast( device=device, requires_grad=True, ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): with torch.autocast(device_type=device.type, dtype=autocast_dtype): y = op(x) y.backward(torch.zeros_like(y)) @@ -533,7 +553,7 @@ def test_pyt_autocast( x.grad = None op.weight.grad = None with torch.autocast(device_type=device.type, dtype=autocast_dtype): - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = op(x) y.backward(torch.zeros_like(y)) assert y.dtype == autocast_dtype @@ -562,7 +582,7 @@ def test_identity( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -628,7 +648,7 @@ def test_reshape( # Skip invalid configurations if memory_format == torch.channels_last and len(in_shape) != 4: pytest.skip("torch.channels_last only supports 4D tensors") - 
maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) with_quantization = quantization is not None # Random data @@ -694,7 +714,7 @@ def test_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -756,7 +776,7 @@ def test_quantize( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8": maybe_skip_quantization(quantization, dims=in_shape) @@ -783,7 +803,7 @@ def test_quantize( # Implementation with fusible operation op = te_ops.Quantize(forward=cast_forward, backward=cast_backward) recipe = make_recipe(quantization) - with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): + with te.autocast(enabled=with_quantization, recipe=recipe): y_test = op(x_test) y_test.backward(dy_test) @@ -823,7 +843,7 @@ def _test_basic_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) quantization_needed = any( ( @@ -877,7 +897,7 @@ def _test_basic_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.BasicLinear( in_features, out_features, @@ -894,7 +914,7 @@ def _test_basic_linear( op, te_ops.Quantize(forward=quantized_output, backward=quantized_grad_output), ) - with 
te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -903,7 +923,7 @@ def _test_basic_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute or quantized_output or quantized_grad_input: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1024,7 +1044,7 @@ def test_linear( out_shape = in_shape[:-1] + [out_features] # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantization is None and (quantized_compute or quantized_weight): pytest.skip("Quantization scheme is not specified") @@ -1065,7 +1085,7 @@ def test_linear( # Implementation with fusible operation recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): op = te_ops.Linear( in_features, out_features, @@ -1081,7 +1101,7 @@ def test_linear( del b_test for param in op.parameters(): param.requires_grad_(requires_grad=weight_requires_grad) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = op(x_test) if input_requires_grad or weight_requires_grad: y_test.backward(dy_test) @@ -1091,7 +1111,7 @@ def test_linear( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1128,7 +1148,7 @@ def test_layer_norm( in_shape = list(in_shape)[:-1] + list(weight_shape) # 
Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1182,14 +1202,14 @@ def test_layer_norm( op, te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1298,7 +1318,7 @@ def test_rmsnorm( in_shape = list(in_shape)[:-1] + list(weight_shape) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1344,7 +1364,7 @@ def test_rmsnorm( op, te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) @@ -1352,16 +1372,8 @@ def test_rmsnorm( # Expected numerical error tols = dtype_tols(dtype) - # Explicit checks for quantization if quantized_compute: - tols = dtype_tols(y_test._quantizer.dtype) - expected_tensor_cls = { - Float8Quantizer:Float8Tensor, - Float8CurrentScalingQuantizer:Float8Tensor, - MXFP8Quantizer:MXFP8Tensor - }[type(y_test._quantizer)] - assert isinstance(y_test, expected_tensor_cls) - y_test = y_test.dequantize(dtype=torch.float32) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1441,7 +1453,7 @@ def 
test_add_extra_input( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x1_ref, x1_test = make_reference_and_test_tensors( @@ -1480,8 +1492,11 @@ def test_add_extra_input( # Check results tols = dtype_tols(dtype) - if with_quantization: - tols = dtype_tols(x1_test._fp8_dtype) + if in_place: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): + tols = dtype_tols(x1_test._fp8_dtype) + elif quantization == "nvfp4": + tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") @@ -1510,7 +1525,7 @@ def test_make_extra_output( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1583,7 +1598,7 @@ def test_activation( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) if cache_quantized_input: maybe_skip_quantization("fp8_current_scaling", device=device) @@ -1651,14 +1666,16 @@ def test_activation( make_op(cache_quantized_input=cache_quantized_input), te_ops.Quantize(forward=quantized_compute, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute or cache_quantized_input: - 
tols = dtype_tols(tex.DType.kFloat8E4M3) + if quantized_compute: + tols = quantization_tols(quantization) + elif cache_quantized_input: + tols = quantization_tols("fp8_current_scaling") # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1689,7 +1706,7 @@ def test_swiglu( quantized_compute = quantization is not None if not quantized_compute and (quantize_forward or quantize_backward): pytest.skip("Quantization scheme has not been provided") - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) # Random data x_ref, x_test = make_reference_and_test_tensors( @@ -1716,13 +1733,87 @@ def test_swiglu( te_ops.SwiGLU(), te_ops.Quantize(forward=quantize_forward, backward=False), ) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) # Expected numerical error tols = dtype_tols(dtype) if quantized_compute: + tols = quantization_tols(quantization) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + @pytest.mark.parametrize("quantize_forward", (False, True)) + @pytest.mark.parametrize("quantize_backward", (False, True)) + def test_clamped_swiglu( + self, + *, + out_shape: Iterable[int] = (32, 32), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantize_forward: bool, + quantize_backward: bool, + limit: float = 0.75, + alpha: float = 1.702, + ): + # Test SwiGLU variant used in GPT OSS. 
+ # Tensor dimensions + in_shape = list(out_shape) + in_shape[-1] *= 2 + + # Skip invalid configurations + quantized_compute = quantization is not None + if not quantized_compute and (quantize_forward or quantize_backward): + pytest.skip("Quantization scheme has not been provided") + maybe_skip_quantization(quantization, dims=in_shape, device=device) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + x_glu, x_linear = x_ref.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + y_ref = out_glu * (x_linear + 1) + y_ref.backward(dy_ref) + + # Implementation with fusible operation + recipe = make_recipe(quantization) + + forward = te_ops.Sequential( + te_ops.Quantize(forward=False, backward=quantize_backward), + te_ops.ClampedSwiGLU(limit=limit, alpha=alpha), + te_ops.Quantize(forward=quantize_forward, backward=False), + ) + with te.autocast(enabled=quantized_compute, recipe=recipe): + y_test = forward(x_test) + + y_test.backward(dy_test) + + # Expected numerical error + tols = dtype_tols(dtype) + if quantized_compute and quantization == "nvfp4": + tols = dtype_tols(tex.DType.kFloat4E2M1) + elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) # Check results @@ -1791,7 +1882,7 @@ def test_dropout( # Skip invalid configurations quantized_input = quantization is not None - maybe_skip_quantization(quantization, dims=shape, device=device) + maybe_skip_quantization(quantization, dims=shape, device=device, dtype=dtype) # Random data # Note: Shift values to make sure inputs are non-zero @@ -1882,7 +1973,7 @@ def test_forward_linear_bias_activation( # Skip invalid configurations quantized_compute = quantization is 
not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if dtype not in (torch.float16, torch.bfloat16): pytest.skip( @@ -1923,7 +2014,7 @@ def test_forward_linear_bias_activation( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_compute, recipe=recipe): + with te.quantized_model_init(enabled=quantized_compute, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -1939,7 +2030,7 @@ def test_forward_linear_bias_activation( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -1953,7 +2044,7 @@ def test_forward_linear_bias_activation( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -1989,7 +2080,7 @@ def test_forward_linear_bias_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2033,7 +2124,7 @@ def test_forward_linear_bias_add( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( 
te_ops.Linear( in_features, @@ -2050,7 +2141,7 @@ def test_forward_linear_bias_add( model[0].bias.copy_(b_test) del w_test del b_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -2064,7 +2155,7 @@ def test_forward_linear_bias_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2102,7 +2193,7 @@ def test_forward_linear_scale_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2139,7 +2230,7 @@ def test_forward_linear_scale_add( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -2155,7 +2246,7 @@ def test_forward_linear_scale_add( with torch.no_grad(): model[0].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x1_test, x2_test) y_test.backward(dy_test) @@ -2170,7 +2261,7 @@ def test_forward_linear_scale_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results 
y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2203,7 +2294,7 @@ def test_backward_activation_bias( # Skip invalid configurations with_quantization = quantization is not None - maybe_skip_quantization(quantization, device=device) + maybe_skip_quantization(quantization, device=device, dtype=dtype) if quantization == "mxfp8" and (len(in_shape) < 2 or in_shape[-1] % 32 != 0): pytest.skip("Unsupported tensor size for MXFP8") @@ -2246,7 +2337,7 @@ def test_backward_activation_bias( with torch.no_grad(): model[1].bias.copy_(b_test) del b_test - with te.fp8_autocast(enabled=with_quantization, fp8_recipe=recipe): + with te.autocast(enabled=with_quantization, recipe=recipe): y_test = model(x_test) y_test.backward(dy_test) @@ -2265,7 +2356,7 @@ def test_backward_activation_bias( # Expected numerical error tols = dtype_tols(dtype) if with_quantization: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2384,7 +2475,7 @@ def test_backward_linear_add( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2424,7 +2515,7 @@ def test_backward_linear_add( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight): + with te.quantized_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.MakeExtraOutput(in_place=True), te_ops.Linear( @@ -2438,7 +2529,7 @@ def test_backward_linear_add( with torch.no_grad(): model[1].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, 
fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y1_test, y2_test = model(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -2452,7 +2543,7 @@ def test_backward_linear_add( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y1_test = y1_test.to(dtype=torch.float64, device="cpu") @@ -2487,7 +2578,7 @@ def test_backward_linear_scale( # Skip invalid configurations quantized_compute = quantization is not None - maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) if quantized_compute and dtype not in (torch.float16, torch.bfloat16): pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") @@ -2519,7 +2610,7 @@ def test_backward_linear_scale( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight): + with te.quantized_model_init(enabled=quantized_weight): model = te_ops.Sequential( te_ops.Linear( in_features, @@ -2533,7 +2624,7 @@ def test_backward_linear_scale( with torch.no_grad(): model[0].weight.copy_(w_test) del w_test - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = model(x_test) (y_test * dy_test).sum().backward() @@ -2547,7 +2638,7 @@ def test_backward_linear_scale( if dtype == torch.float32: tols = dtype_tols(torch.float16) # TF32 GEMM if quantized_compute: - tols = dtype_tols(tex.DType.kFloat8E4M3) + tols = quantization_tols(quantization) # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") @@ -2588,12 +2679,12 @@ def test_linear( # Skip invalid configurations quantized_compute = quantization is not None - 
maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=out_shape) # Construct model recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model_save = te_ops.Sequential( te_ops.Linear(in_features, out_features, device=device, dtype=dtype) ) @@ -2604,7 +2695,7 @@ def test_linear( x = torch.randn(in_shape, dtype=dtype, device=device, requires_grad=True) dy = torch.randn(out_shape, dtype=dtype, device=device) optim_save.zero_grad() - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = model_save(x) y.backward(dy) optim_save.step() @@ -2633,14 +2724,14 @@ def test_linear( ys_save = [] for i in range(post_checkpoint_steps): optim_save.zero_grad() - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = model_save(xs_save[i]) y.backward(dys[i]) optim_save.step() ys_save.append(y) # Load checkpoint - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): model_load = te_ops.Sequential( te_ops.Linear(in_features, out_features, device=device, dtype=dtype) ) @@ -2653,7 +2744,7 @@ def test_linear( ys_load = [] for i in range(post_checkpoint_steps): optim_load.zero_grad() - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y = model_load(xs_load[i]) y.backward(dys[i]) optim_load.step() @@ -2714,7 +2805,7 @@ def test_layernorm_mlp( ffn_shape = in_shape[:-1] + (ffn_hidden_size,) # Skip invalid configurations - maybe_skip_quantization(quantization, dims=in_shape, device=device) + 
maybe_skip_quantization(quantization, dims=in_shape, device=device, dtype=dtype) maybe_skip_quantization(quantization, dims=ffn_shape, device=device) quantization_needed = quantized_compute or quantized_weight if quantization is None and quantization_needed: @@ -2740,7 +2831,7 @@ def test_layernorm_mlp( # Implementation with fusible operations recipe = make_recipe(quantization) - with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): if normalization == "LayerNorm": norm = te_ops.LayerNorm( hidden_size, @@ -2771,6 +2862,6 @@ def test_layernorm_mlp( dtype=dtype, ) forward = te_ops.Sequential(norm, ffn1, act, ffn2) - with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + with te.autocast(enabled=quantized_compute, recipe=recipe): y_test = forward(x_test) y_test.backward(dy_test) diff --git a/tests/pytorch/test_hf_integration.py b/tests/pytorch/test_hf_integration.py index e74b16022..b014201c2 100644 --- a/tests/pytorch/test_hf_integration.py +++ b/tests/pytorch/test_hf_integration.py @@ -6,7 +6,7 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_utils import PreTrainedModel -from transformer_engine.pytorch.transformer import TransformerLayer +from transformer_engine.pytorch import TransformerLayer class SimpleTEModel(PreTrainedModel): diff --git a/tests/pytorch/test_layernorm_saved_tensors_logic.py b/tests/pytorch/test_layernorm_saved_tensors_logic.py index cb7760b5f..ab7d5d9d4 100644 --- a/tests/pytorch/test_layernorm_saved_tensors_logic.py +++ b/tests/pytorch/test_layernorm_saved_tensors_logic.py @@ -1,11 +1,11 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information import pytest import torch import torch.nn as nn from unittest.mock import patch -from transformer_engine.pytorch import LayerNormLinear, LayerNormMLP, fp8_autocast +from transformer_engine.pytorch import LayerNormLinear, LayerNormMLP, autocast from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.fp8 import FP8GlobalStateManager @@ -84,7 +84,7 @@ def spy_on_ctx(ctx, *args, **kwargs): weight_tensor.requires_grad_(True) with patch(config["backward_target"], side_effect=spy_on_ctx) as mock_backward: - with fp8_autocast(enabled=True): + with autocast(enabled=True): out, ln_out_returned = model(inp) out.backward(grad_output, retain_graph=True) @@ -99,7 +99,7 @@ def spy_on_ctx(ctx, *args, **kwargs): saved_ln_out_container.clear() with patch(config["backward_target"], side_effect=spy_on_ctx) as mock_backward: - with fp8_autocast(enabled=True): + with autocast(enabled=True): out, ln_out_returned = model(inp) out.backward(grad_output) diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py index cc6f959ab..340d8e4d2 100644 --- a/tests/pytorch/test_multi_tensor.py +++ b/tests/pytorch/test_multi_tensor.py @@ -5,7 +5,7 @@ import pytest import torch -import transformer_engine.pytorch as te +import transformer_engine.pytorch import transformer_engine_torch as tex from transformer_engine.pytorch.optimizers import MultiTensorApply from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a4dfd64ba..9fe6304d4 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -15,21 +15,18 @@ from torch.nn import Parameter from torch.utils.cpp_extension import IS_HIP_EXTENSION -from transformer_engine.pytorch.fp8 import ( - FP8GlobalStateManager, - fp8_autocast, - fp8_model_init, -) +from transformer_engine.pytorch.quantization import 
FP8GlobalStateManager from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, attention_mask_func, - is_bf16_compatible, ) if IS_HIP_EXTENSION: from transformer_engine.pytorch.utils import is_mi200, is_mi308 from transformer_engine.pytorch import ( + autocast, + quantized_model_init, DotProductAttention, LayerNormLinear, LayerNormMLP, @@ -41,29 +38,31 @@ LayerNorm, Fp8Padding, Fp8Unpadding, -) -from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils -from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, Float8CurrentScalingQuantizer, + MXFP8Quantizer, + get_device_compute_capability, + is_fp8_available, + is_mxfp8_available, + is_fp8_block_scaling_available, + is_bf16_available, + is_nvfp4_available, ) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils as fa_utils +from transformer_engine.pytorch import checkpoint as te_checkpoint +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace -from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import transformer_engine_torch as tex -from utils import ModelConfig, reset_rng_states, get_available_attention_backends +from utils import ModelConfig, reset_rng_states if IS_HIP_EXTENSION: from utils import EnvVarCleaner - # Only run FP8 tests on supported devices. 
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) +fp8_block_scaling_available = is_fp8_block_scaling_available() +nvfp4_available = is_nvfp4_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -97,7 +96,7 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: input_formats_inference = ["sbhd", "bshd"] param_types = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) batch_sizes = [1, 2] @@ -136,6 +135,43 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: ) +def nvfp4_rht_and_2d_quantization(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams( + random_hadamard_transform=False, fp4_2d_quantization=True + ) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams( + random_hadamard_transform=True, fp4_2d_quantization=False + ) + return nvfp4_recipe + + +def check_rht_usage(recipe: recipe.Recipe) -> bool: + # if using RHT, we can only support bf16 + # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad + if recipe.nvfp4(): + if ( + recipe.fp4_quant_fwd_inp.random_hadamard_transform + or recipe.fp4_quant_fwd_weight.random_hadamard_transform + or recipe.fp4_quant_bwd_grad.random_hadamard_transform + ): + return True + return False + + +def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> bool: + supported_input_dtypes = [] + if recipe.nvfp4(): + 
supported_input_dtypes.append(torch.bfloat16) + # if not using RHT, we can add fp32 as well + if not check_rht_usage(recipe): + supported_input_dtypes.append(torch.float32) + return supported_input_dtypes + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) @@ -144,6 +180,8 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: if fp8_available: fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +if nvfp4_available: + fp8_recipes.append(nvfp4_rht_and_2d_quantization()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper @@ -151,25 +189,6 @@ def rocm_attn_backend() -> tuple[bool, bool, bool]: use_cutlass_grouped_gemm.append(True) -def is_fused_attn_available( - config: ModelConfig, - dtype: torch.dtype, - qkv_layout="bshd_bshd_bshd", - is_training=True, - deterministic=False, -): - _, _, fused_attn_backends = get_available_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - is_training=is_training, - deterministic=deterministic, - ) - if IS_HIP_EXTENSION: - return fused_attn_backends != [] - return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends - - def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -583,7 +602,7 @@ def _test_e2e_selective_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -610,7 +629,7 @@ def _test_e2e_selective_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): te_out = block( 
te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -636,6 +655,11 @@ def _test_e2e_selective_recompute( def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) config = model_configs[model] @@ -673,7 +697,7 @@ def _test_e2e_full_recompute( init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -701,7 +725,7 @@ def _test_e2e_full_recompute( te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): if recompute: te_out = te_checkpoint( block, @@ -754,6 +778,11 @@ def test_gpt_full_activation_recompute( and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") + if fp8 and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow @@ -905,9 +934,6 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] - if not is_fused_attn_available(config, dtype, deterministic=True): - 
pytest.skip("No attention backend available.") - outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -960,10 +986,6 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] - if not is_fused_attn_available( - config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True - ): - pytest.skip("No attention backend available.") te_gpt = TransformerLayer( hidden_size=config.hidden_size, @@ -1075,10 +1097,6 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] - if not is_fused_attn_available( - config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True - ): - pytest.skip("No attention backend available.") te_mha = MultiheadAttention( config.hidden_size, @@ -1146,7 +1164,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, ) inp_hidden_states.retain_grad() - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): out = block(inp_hidden_states) if isinstance(out, (List, Tuple)): out = out[0] @@ -1178,7 +1196,7 @@ def _test_granular_accuracy_with_fp8(block, bs, dtype, config): ) inp_hidden_states.retain_grad() - with fp8_autocast(enabled=True): + with autocast(enabled=True): out = block(inp_hidden_states) loss = out.sum() loss.backward() @@ -1339,10 +1357,11 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model module = LayerNormLinear config = model_configs[model] - with fp8_model_init(enabled=fp8_model_params): + with quantized_model_init(enabled=fp8_model_params): layer = module( config.hidden_size, 4 * config.hidden_size, + config.eps, 
bias=True, params_dtype=dtype, device="cuda", @@ -1353,6 +1372,7 @@ def test_fp8_linear_without_transpose_cache_accuracy(dtype, bs, model, fp8_model ref_layer = module( config.hidden_size, 4 * config.hidden_size, + config.eps, bias=True, params_dtype=dtype, device="cuda", @@ -1443,7 +1463,13 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): te_linear_ref = Linear( config.hidden_size, 4 * config.hidden_size, @@ -1967,8 +1993,8 @@ def _test_grouped_linear_accuracy( split_size = 1 if fp8: split_size = 16 - if recipe.mxfp8(): - split_size = 128 + if recipe.mxfp8() or recipe.nvfp4(): + split_size = 32 m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero @@ -1978,7 +2004,7 @@ def _test_grouped_linear_accuracy( else: m_splits = torch.tensor([config.max_seqlen_q]) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): if isinstance(block, GroupedLinear): m_splits = m_splits * bs out = block(inp_hidden_states, m_splits.tolist()) @@ -2044,7 +2070,13 @@ def test_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe 
{recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2180,7 +2212,13 @@ def test_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2267,7 +2305,7 @@ def _test_padding_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, r def _pad_tensor_for_fp8(hidden_states, tokens_per_expert): align_size = 16 - if recipe.mxfp8(): + if recipe.mxfp8() or recipe.nvfp4(): align_size = 32 padded_tokens_per_expert = [ (num_tokens + align_size - 1) // align_size * align_size @@ -2334,7 +2372,7 @@ def _generate_random_numbers(n, total_sum): m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) - with fp8_autocast(enabled=fp8, fp8_recipe=recipe): + with autocast(enabled=fp8, recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): out = block(inp_hidden_states, m_splits) else: @@ -2382,7 +2420,13 @@ def test_padding_grouped_linear_accuracy( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, 
recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -2393,7 +2437,7 @@ def test_padding_grouped_linear_accuracy( fp8=fp8, ).eval() - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2453,7 +2497,13 @@ def test_padding_grouped_linear_accuracy_save_original_input( if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + if recipe is not None and recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): grouped_linear = TorchGroupedLinearWithPadding( num_gemms, config.hidden_size, @@ -2464,7 +2514,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( fp8=fp8, ).eval() - with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): ref_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, @@ -2620,7 +2670,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - with fp8_model_init(enabled=fp8_model_params, recipe=recipe): + with quantized_model_init(enabled=fp8_model_params, recipe=recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, @@ -2647,7 +2697,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): te_inp_hidden_states.retain_grad() te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) - with fp8_autocast(enabled=True, 
fp8_recipe=recipe): + with autocast(enabled=True, recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -2668,6 +2718,12 @@ def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if recipe.nvfp4(): + if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): + pytest.skip( + f"Input dtype {dtype} not supported for NVFP4 Recipe {recipe.__class__.__name__}" + ) + config = model_configs[model] outputs = _test_gpt_fp8_parameters(bs, dtype, config, False, recipe) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index e5368497d..2ce6eb82b 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -34,7 +34,7 @@ from transformer_engine.common import recipe import transformer_engine_torch as tex from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.utils import get_default_init_method import tensorrt as trt @@ -57,8 +57,8 @@ # The directory where this file is stored. 
TESTS_DIR = os.path.dirname(os.path.abspath(__file__)) -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) fp8_recipes = [] if mxfp8_available: @@ -68,7 +68,7 @@ fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(None) -supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] +supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -178,8 +178,8 @@ def do_export( input_names = input_names or ["input"] output_names = output_names or ["output"] - with torch.inference_mode(), te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + with torch.inference_mode(), te.autocast( + enabled=fp8_recipe is not None, recipe=fp8_recipe ), warnings.catch_warnings(): warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*") @@ -233,8 +233,8 @@ def te_infer( fp8_recipe: recipe.Recipe, ): """Transformer Engine forward propagation.""" - with torch.inference_mode(), te.fp8_autocast( - enabled=is_fp8, fp8_recipe=fp8_recipe + with torch.inference_mode(), te.autocast( + enabled=is_fp8, recipe=fp8_recipe ), warnings.catch_warnings(): te_outputs = model(*inps if isinstance(inps, tuple) else (inps,)) if not isinstance(te_outputs, tuple): @@ -440,7 +440,7 @@ def forward(self, inp): bias_str = "_bias" if use_bias else "" high_prec_str = dtype2str(precision) fname = f"te.linear{fp8_str}{bias_str}{high_prec_str}.onnx" - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): model = Test_Linear(in_features, out_features, use_bias, return_bias, precision).to( device="cuda" ) @@ -500,7 
+500,7 @@ def _test_export_layernorm( fname = f"te.layernorm_linear{fp8_str}{high_prec_str}.onnx" with torch.no_grad(): - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm model = layernorm_cls( hidden_size, @@ -568,7 +568,7 @@ def _test_export_layernorm_linear( fname = f"te.layernorm_linear{fp8_str}{bias_str}{high_prec_str}.onnx" with torch.no_grad(): - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): model = te.LayerNormLinear( hidden_size, 3 * hidden_size, @@ -654,7 +654,7 @@ def _test_export_layernorm_mlp( bias_str = "_bias" if use_bias else "" high_prec_str = dtype2str(precision) fname = f"te.layernorm_mlp{fp8_str}{bias_str}{high_prec_str}_{activation}.onnx" - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): model = te.LayerNormMLP( hidden_size, ffn_hidden_size, @@ -1160,13 +1160,13 @@ def test_trt_integration(fp8_recipe: recipe.Recipe): inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): out_ref = model(*inps) onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx") os.close(onnx_fd) try: - with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe): with te.onnx_export(enabled=True): torch.onnx.export( model, diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index fa56852ff..e325146b7 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ 
b/tests/pytorch/test_parallel_cross_entropy.py @@ -4,7 +4,7 @@ import random import torch -from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch import parallel_cross_entropy from utils import dtype_tols diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index 5202155e2..ad4d6b622 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -10,6 +10,7 @@ import pytest from typing import Dict, List +import transformer_engine.pytorch as te from transformer_engine.common import recipe from transformer_engine.pytorch import ( moe_permute as te_permute, @@ -18,14 +19,12 @@ moe_sort_chunks_by_index as te_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs as te_sort_chunks_by_index_with_probs, ) -from transformer_engine.pytorch.utils import is_bf16_compatible -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.tensor.float8_tensor import ( +from transformer_engine.pytorch import ( Float8Quantizer, Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, ) -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer import transformer_engine_torch as tex import copy @@ -1125,7 +1124,7 @@ def perf_test_cuda_kernel(cuda_kernel_fn): # TE tensor dtypes _te_dtypes: List[tex.DType] = [tex.DType.kFloat32, tex.DType.kFloat16] -if is_bf16_compatible(): +if te.is_bf16_available(): _te_dtypes.append(tex.DType.kBFloat16) @@ -1245,10 +1244,10 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): # Only run FP8 tests on H100. 
-fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True ) fp8_recipes = [ recipe.MXFP8BlockScaling(), diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index c59bf376a..6850be9b4 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -1,10 +1,10 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
-from typing import Iterable, Optional +from typing import Optional import pytest import torch @@ -12,28 +12,35 @@ import transformer_engine.common.recipe import transformer_engine.pytorch as te -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch import ( + Float8BlockQuantizer, + MXFP8Quantizer, + Float8Quantizer, + NVFP4Quantizer, + quantized_model_init, + Linear, + LayerNormLinear, + LayerNormMLP, + GroupedLinear, +) + import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import ( +from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, _amax_and_scale_update, - fp8_model_init, ) -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.utils import is_fp8_fnuz -from transformer_engine.pytorch import Linear, LayerNormLinear, LayerNormMLP, GroupedLinear -from transformer_engine.pytorch.distributed import fp8_autocast from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling import transformer_engine_torch as tex # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True ) +fp4_available, reason_for_no_fp4 = te.is_nvfp4_available(return_reason=True) # FP8 per tensor delayed scaling @@ -66,7 +73,7 @@ def 
test_fp8_scale_update_with_linear_module( amax_history_len=amax_history_len, amax_compute_algo=amax_compute_algo, ) - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): module = te.Linear(16, 16) y = module( torch.randn([16, 16], device="cuda"), @@ -122,7 +129,7 @@ def test_fp8_scale_update_with_linear_module( # ref_scale_inv_backward = torch.reciprocal(ref_scale_backward) # Perform forward, backward, and optimizer steps to update fp8_meta - with te.fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.autocast(enabled=True, recipe=recipe): x = torch.randn([16, 16], device="cuda") y = module(x, is_first_microbatch=is_first_microbatch) y.backward(torch.randn_like(y)) @@ -221,7 +228,7 @@ def test_fp8_scale_update_with_linear_fuser_op( op.weight.fill_(w_history[-1]) # Forward and backward pass - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): y = op(x) y.backward(dy) @@ -303,7 +310,7 @@ def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype scaling_factor_compute_algo = None if fused_update: scaling_factor_compute_algo = ( - lambda amax, scale, fp8_max, recipe: te.fp8._default_sf_compute( + lambda amax, scale, fp8_max, recipe: te.quantization._default_sf_compute( amax, scale, fp8_max, recipe.margin ) ) @@ -313,7 +320,7 @@ def test_scale_update_numeric_scenarios(self, amax_case, fused_update, fp8_dtype # Setup fp8_meta dictionary def setup_fp8_meta(): - with te.fp8_autocast(fp8_recipe=recipe): + with te.autocast(recipe=recipe): module = te.Linear(16, 16) y = module(torch.zeros([16, 16], device="cuda")) y.backward(torch.zeros_like(y)) @@ -395,11 +402,11 @@ def setup_fp8_meta(): ], ) def test_check_for_weight_tensor_and_recipe_correspondence(self, model_init_recipe): - with fp8_model_init(enabled=True, recipe=model_init_recipe): + with quantized_model_init(enabled=True, recipe=model_init_recipe): linear = Linear(32, 32).cuda() x = torch.randn(32, 32, device="cuda") - with 
fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()): + with te.autocast(enabled=True, recipe=DelayedScaling()): with pytest.raises(RuntimeError) as excinfo: _ = linear(x) assert "Recipe mismatch for " in str(excinfo.value) @@ -438,7 +445,7 @@ def test_dynamic_recipe_update( # Run initial iterations with DelayedScaling for _ in range(3): x = torch.randn(batch_size, in_features, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=initial_recipe): + with te.autocast(enabled=True, recipe=initial_recipe): y = linear(x) loss = y.mean() loss.backward() @@ -455,7 +462,7 @@ def test_dynamic_recipe_update( if i == 0: # Expect a warning on the first iteration with the new recipe with pytest.warns(UserWarning, match="Recipe type changed"): - with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + with te.autocast(enabled=True, recipe=target_recipe): y = linear(x) for quantizer in linear.quantizers["scaling_fwd"]: assert isinstance(quantizer, expected_quantizer_type) @@ -463,7 +470,7 @@ def test_dynamic_recipe_update( # No warning expected on subsequent iterations with warnings.catch_warnings(): warnings.simplefilter("error") # Raise error if unexpected warning occurs - with fp8_autocast(enabled=True, fp8_recipe=target_recipe): + with te.autocast(enabled=True, recipe=target_recipe): y = linear(x) loss = y.mean() loss.backward() @@ -487,7 +494,7 @@ def test_quantizer_update(self, module_class): batch_size = 32 recipe = DelayedScaling(amax_history_len=1024) - with fp8_model_init(recipe=recipe): + with quantized_model_init(recipe=recipe): if module_class == GroupedLinear: module = module_class(1, in_features, out_features).cuda() else: @@ -495,10 +502,43 @@ def test_quantizer_update(self, module_class): x = torch.randn(batch_size, in_features, device="cuda") recipe = DelayedScaling(amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=recipe): + with te.autocast(enabled=True, recipe=recipe): warn_msg = "Quantizer is being updated, this may affect model 
behavior" with pytest.warns(UserWarning, match=warn_msg): if module_class == GroupedLinear: y = module(x, [batch_size]) else: y = module(x) + + +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "M, N", + [ + # full tile cases + (128, 128), + (256, 1024), + (1024, 256), + # Padding required cases + (256, 272), + (304, 304), + (320, 256), + # # largest tile + (8192, 8192), + ], +) +def test_fp4_dequantize(dtype, M, N): + q = NVFP4Quantizer() + a = torch.rand((M, N)).cuda().to(dtype=dtype) + starting_tensor = q(a) + dequantized_tensor = starting_tensor.dequantize() + new_tensor = q(dequantized_tensor) + torch.testing.assert_close( + new_tensor._rowwise_data, + starting_tensor._rowwise_data, + rtol=0, + atol=0, + ) + new_dequantized_tensor = new_tensor.dequantize() + torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index a7d762c3d..b08d812cd 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -11,18 +11,16 @@ from torch.utils.cpp_extension import IS_HIP_EXTENSION import os -import transformer_engine.pytorch -from transformer_engine.pytorch.fp8 import ( - fp8_autocast, - FP8GlobalStateManager, - fp8_model_init, -) +import transformer_engine +import transformer_engine.pytorch as te +from transformer_engine.pytorch.quantization import FP8GlobalStateManager from transformer_engine.pytorch.utils import ( init_method_normal, scaled_init_method_normal, - is_bf16_compatible, ) from transformer_engine.pytorch import ( + autocast, + quantized_model_init, LayerNormLinear, Linear, GroupedLinear, @@ -30,26 +28,25 @@ TransformerLayer, RMSNorm, LayerNorm, + Float8CurrentScalingQuantizer, + Float8Quantizer, + Float8Tensor, + MXFP8Tensor, + checkpoint, + QuantizedTensor, + is_bf16_available, ) from transformer_engine.common import recipe import 
transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.module.base import get_workspace -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import ( - Float8CurrentScalingQuantizer, - Float8Quantizer, - Float8Tensor, -) -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data -from transformer_engine.pytorch.distributed import checkpoint from utils import ModelConfig # Only run FP8 tests on supported devices. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) # Record initial RNG state from script run. 
seed = 1234 @@ -90,9 +87,19 @@ def is_fp8_supported(config: ModelConfig): "large": ModelConfig(2, 128, 4, 128, num_layers=1), } + +def nvfp4_vanilla(): + nvfp4_recipe = recipe.NVFP4BlockScaling() + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) + fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -101,7 +108,7 @@ def is_fp8_supported(config: ModelConfig): fp8_recipes.append(None) param_types = [torch.float32, torch.float16] -if is_bf16_compatible(): # bf16 requires sm_80 or higher +if is_bf16_available(): # bf16 requires sm_80 or higher param_types.append(torch.bfloat16) all_boolean = [True, False] @@ -118,6 +125,7 @@ def is_fp8_supported(config: ModelConfig): "sreglu", "silu", "swiglu", + "clamped_swiglu", ] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -153,7 +161,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): use_fp8 = fp8_recipe is not None with torch.autocast(device_type="cuda", enabled=True, dtype=dtype): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() @@ -192,7 +200,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci p.main_grad = torch.zeros_like(p) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -220,7 +228,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): _disable_wgrads(block) 
use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states) loss = te_out.sum() loss.backward() @@ -246,7 +254,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) loss = te_out.sum() loss.backward() @@ -278,7 +286,7 @@ def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): te_out = block( te_inp_hidden_states, attention_mask=te_inp_attn_mask, @@ -307,7 +315,7 @@ def _test_sanity_common( _disable_wgrads(block) use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): if not microbatching: te_out = block(te_inp) else: @@ -382,6 +390,8 @@ def test_sanity_layernorm_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -410,6 +420,8 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) @@ -440,9 +452,11 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, 
fp8_recipe, fp8_model_ if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): + with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_linear = Linear( config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -450,7 +464,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): out = te_linear(inp_hidden_states) loss = out.sum() loss.backward() @@ -479,9 +493,11 @@ def test_sanity_grouped_linear( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4(): + pytest.skip("NVFP4 not supported for grouped linear") use_fp8 = fp8_recipe is not None - with fp8_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): + with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_grouped_linear = GroupedLinear( num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype ).cuda() @@ -497,7 +513,7 @@ def test_sanity_grouped_linear( elif empty_split == "middle": m_splits[num_gemms // 2] = 0 - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with autocast(enabled=use_fp8, recipe=fp8_recipe): out = te_grouped_linear(inp_hidden_states, m_splits) loss = out.sum() loss.backward() @@ -529,11 +545,13 @@ def test_sanity_layernorm_mlp( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype 
== torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - + activation_params = None if activation != "clamped_swiglu" else {"limit": 7.0, "alpha": 1.702} block = LayerNormMLP( config.hidden_size, 4 * config.hidden_size, @@ -541,6 +559,7 @@ def test_sanity_layernorm_mlp( output_layer_init_method=output_layer_init_method, zero_centered_gamma=zero_centered_gamma, activation=activation, + activation_params=activation_params, normalization=normalization, params_dtype=dtype, device="cuda", @@ -571,6 +590,8 @@ def test_sanity_gpt( if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -632,6 +653,8 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -686,6 +709,8 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization): pytest.skip(reason_for_no_fp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -737,6 +762,8 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 
init_method = init_method_normal(sigma) @@ -767,6 +794,8 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -801,6 +830,8 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -835,6 +866,8 @@ def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgra if fp8_recipe is not None: if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") + if fp8_recipe.nvfp4() and dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") sigma = 0.023 init_method = init_method_normal(sigma) @@ -945,9 +978,9 @@ def test_replace_raw_data_for_float8tensor(): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -def test_fp8_model_init_high_precision_init_val(): - """Test fp8_model_init with preserve_high_precision_init_val=True""" - with fp8_model_init(preserve_high_precision_init_val=True): +def test_quantized_model_init_high_precision_init_val(): + """Test quantized_model_init with preserve_high_precision_init_val=True""" + with quantized_model_init(preserve_high_precision_init_val=True): model = Linear(768, 768) weight = model.weight @@ -1020,7 +1053,7 @@ def test_linear_frozen_weights_memory_default_recipe(): linear.weight.requires_grad = False # Forward and backward pass with FP8 - with fp8_autocast(): + with autocast(): o = linear(x) g_o = torch.randn_like(o) @@ -1074,7 +1107,7 @@ def test_inference_mode( # Construct module module = None with 
torch.no_grad(): - with fp8_model_init(enabled=with_quantization, recipe=quantization_recipe): + with quantized_model_init(enabled=with_quantization, recipe=quantization_recipe): if module_name == "Linear": module = Linear(hidden_size, hidden_size) elif module_name == "LayerNormLinear": @@ -1109,6 +1142,6 @@ def check_weights(): kwargs = {} if module_name == "GroupedLinear": kwargs["m_splits"] = [sequence_length] - with fp8_autocast(enabled=with_quantization, fp8_recipe=quantization_recipe): + with autocast(enabled=with_quantization, recipe=quantization_recipe): y = module(x, **kwargs) check_weights() diff --git a/tests/pytorch/triton_kernels/test_cast.py b/tests/pytorch/triton_kernels/test_cast.py index 3f725c496..f85773d65 100644 --- a/tests/pytorch/triton_kernels/test_cast.py +++ b/tests/pytorch/triton_kernels/test_cast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information import pytest @@ -10,7 +10,7 @@ from transformer_engine.pytorch.triton_kernels.common import te_dtype_to_torch_dtype import transformer_engine_torch as tex from test_common import te_compare_results, fill_uniform, get_tolerances -from transformer_engine.pytorch.fp8 import fp8_autocast +from transformer_engine.pytorch.fp8 import autocast from transformer_engine.common import recipe from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type @@ -43,7 +43,7 @@ def test_quantize(scaling, shape, in_dtype, out_dtype): triton_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda") tex_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=out_dtype, device="cuda") - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): quantized_out_triton = te_quantize_triton(input_tensor, quantizer=triton_quantizer) quantized_out_tex = tex.quantize(input_tensor, tex_quantizer) @@ -187,13 +187,13 @@ def test_amax_atomic_vs_two_stage(shape, in_dtype, out_dtype): # atomic amax os.environ[env_key] = "1" - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): out_atomic = te_quantize_triton(input_tensor, quantizer=quantizer_atomic) # 2-stage amax os.environ[env_key] = "0" - with fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()): + with autocast(enabled=True, recipe=recipe.Float8CurrentScaling()): out_2stage = te_quantize_triton(input_tensor, quantizer=quantizer_2stage) te_compare_results( diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 684c15737..ed5a12995 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -9,23 +9,24 @@ import logging import os from contextlib import contextmanager +from typing import Optional, Tuple, Dict, Any, List -import pytest import torch from 
torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine -import transformer_engine.common.recipe -import transformer_engine.pytorch as te -from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type import transformer_engine_torch as tex +from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch import InferenceParams from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends from transformer_engine.pytorch.attention.dot_product_attention.utils import ( get_attention_backend, AttentionParams, AttentionLogging, + check_set_window_size, ) from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from transformer_engine.pytorch.utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type torch_float8_e4m3_type = get_torch_float8_e4m3_type() torch_float8_e5m2_type = get_torch_float8_e5m2_type() @@ -78,6 +79,8 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: # Transformer Engine dtypes if isinstance(dtype, tex.DType): + if dtype == tex.DType.kFloat4E2M1: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.25 dtype = { tex.DType.kByte: torch.uint8, tex.DType.kInt32: torch.int32, @@ -97,13 +100,28 @@ def dtype_tols(dtype: torch.dtype | tex.DType) -> dict[str, float]: return dict(rtol=1.3e-6, atol=1e-5) if dtype == torch.float64: return dict(rtol=1e-7, atol=1e-7) - if dtype == torch.float8_e4m3fn or dtype == torch.float8_e4m3fnuz: + if dtype in torch_float8_e4m3_type: return dict(rtol=0.125, atol=0.0675) # epsilon = 0.0625 - if dtype == torch.float8_e5m2 or dtype == torch.float8_e5m2fnuz: - return dict(rtol=0.25, atol=0.125) # epsilon = 0.152 + if dtype in torch_float8_e5m2_type: + return dict(rtol=0.25, atol=0.125) # epsilon = 0.125 raise ValueError(f"Unsupported dtype ({dtype})") +def quantization_tols(name: str) -> dict[str, float]: + """Estimated numerical error for a quantization scheme""" + if name 
in ( + "fp8", + "fp8_delayed_scaling", + "fp8_current_scaling", + "mxfp8", + "mxfp8_block_scaling", + ): + return dtype_tols(tex.DType.kFloat8E4M3) + if name == "nvfp4": + return dtype_tols(tex.DType.kFloat4E2M1) + raise ValueError(f"Unsupported quantization scheme ({name})") + + def make_recipe(name: Optional[str]) -> Optional[Recipe]: """Make recipe for quantization scheme""" if name is None: @@ -123,6 +141,12 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() + if name == "nvfp4": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + ) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -143,6 +167,31 @@ def reset_rng_states() -> None: torch.cuda.set_rng_state(cuda_rng_state) +def compare_and_assert(a, b, name_a, name_b, atol, rtol, rmse_tol, is_fp8): + if not is_fp8: + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + return + + try: + if a.dtype != b.dtype: + a = a.to(b.dtype) + torch.testing.assert_close(a, b, atol=atol, rtol=rtol) + except Exception as e: + logging.debug(e) + + rmse = torch.sqrt((a - b).square().mean()).item() + logging.debug(name_a + " vs " + name_b + " RMSE: {:.6f}".format(rmse)) + rmse_range = max(a.max().item(), b.max().item()) - min(a.min().item(), b.min().item()) + assert rmse < rmse_tol * rmse_range, ( + name_a + + " vs " + + name_b + + " RMSE {:.5f} is over tolerance {:.5f} ({:.5f} * {:.5f})".format( + rmse, rmse_tol * rmse_range, rmse_tol, rmse_range + ) + ) + + class ModelConfig: def __init__( self, @@ -153,12 +202,16 @@ def __init__( max_seqlen_kv: int = None, num_gqa_groups: int = None, head_dim_v: int = None, + softmax_type: str = "vanilla", dropout_p: float = 0.0, attn_mask_type: str = "no_mask", attn_bias_type: str = "no_bias", alibi_type: str = "none", bias_shape: str = "1hss", window_size: Tuple[int, int] = 
(-1, -1), + context_parallel: bool = False, + cp_comm_type: str = "p2p", + return_max_logit=False, total_requests: int = None, max_ctx_len: int = None, num_layers: int = 1, @@ -177,13 +230,17 @@ def __init__( self.kv_channels = (self.head_dim_qk, self.head_dim_v) self.hidden_size = self.num_heads * self.head_dim_qk self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.softmax_type = softmax_type self.dropout_p = dropout_p self.attn_mask_type = attn_mask_type self.attn_bias_type = attn_bias_type self.alibi_type = alibi_type self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" self.bias_shape = bias_shape - self.window_size = window_size + self.window_size = check_set_window_size(self.attn_mask_type, window_size) + self.context_parallel = context_parallel + self.cp_comm_type = cp_comm_type + self.return_max_logit = return_max_logit self.total_requests = total_requests self.max_ctx_len = max_ctx_len self.num_layers = num_layers @@ -221,9 +278,7 @@ def get_available_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), pad_between_seqs: bool = False, - context_parallel: bool = False, deterministic: bool = False, fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, @@ -276,19 +331,22 @@ def test(): head_dim_qk=config.head_dim_qk, head_dim_v=config.head_dim_v, attn_mask_type=config.attn_mask_type, - window_size=window_size, + window_size=config.window_size, alibi_slopes_shape=alibi_slopes_shape, core_attention_bias_type=config.attn_bias_type, core_attention_bias_shape=core_attention_bias_shape, core_attention_bias_requires_grad=core_attention_bias_requires_grad, pad_between_seqs=pad_between_seqs, attention_dropout=config.dropout_p, - context_parallel=context_parallel, + context_parallel=config.context_parallel, + cp_comm_type=config.cp_comm_type, deterministic=deterministic, fp8=fp8, fp8_meta=fp8_meta, is_training=is_training, 
inference_params=inference_params, + softmax_type=config.softmax_type, + return_max_logit=config.return_max_logit, ) ( use_flash_attention, diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cefec6d06..831de2b45 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -35,15 +35,6 @@ endif() # Language options if(USE_CUDA) # Removed indent to minimize code diff with NV upstream -if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) - set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) - elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) - else () - set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) - endif() -endif() set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) @@ -60,8 +51,62 @@ project(transformer_engine LANGUAGES CUDA CXX) # CUDA Toolkit find_package(CUDAToolkit REQUIRED) -if (CUDAToolkit_VERSION VERSION_LESS 12.0) - message(FATAL_ERROR "CUDA 12.0+ is required, but found CUDA ${CUDAToolkit_VERSION}") +if (CUDAToolkit_VERSION VERSION_LESS 12.1) + message(FATAL_ERROR "CUDA 12.1+ is required, but found CUDA ${CUDAToolkit_VERSION}") +endif() + +# Process GPU architectures +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 13.0) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + elseif (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 70 80 89 90) + endif() +endif() + 
+# Process CMAKE_CUDA_ARCHITECTURES to separate generic and specific architectures +set(NVTE_GENERIC_ARCHS) +set(NVTE_SPECIFIC_ARCHS) + +# Check for architecture 100 +list(FIND CMAKE_CUDA_ARCHITECTURES "100" arch_100_index) +if(NOT arch_100_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "100") + list(APPEND NVTE_GENERIC_ARCHS "100") + list(APPEND NVTE_SPECIFIC_ARCHS "100a") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND NVTE_SPECIFIC_ARCHS "103a") + endif() +endif() + +# Check for architecture 101 (if we see this we are in toolkit <= 12.9) +list(FIND CMAKE_CUDA_ARCHITECTURES "101" arch_101_index) +if(NOT arch_101_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "101") + list(APPEND NVTE_GENERIC_ARCHS "101") + list(APPEND NVTE_SPECIFIC_ARCHS "101a") +endif() + +# Check for architecture 110 (if we see this we are in toolkit >= 13.0) +list(FIND CMAKE_CUDA_ARCHITECTURES "110" arch_110_index) +if(NOT arch_110_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "110") + list(APPEND NVTE_GENERIC_ARCHS "110") + list(APPEND NVTE_SPECIFIC_ARCHS "110f") +endif() + +# Check for architecture 120 +list(FIND CMAKE_CUDA_ARCHITECTURES "120" arch_120_index) +if(NOT arch_120_index EQUAL -1) + list(REMOVE_ITEM CMAKE_CUDA_ARCHITECTURES "120") + list(APPEND NVTE_GENERIC_ARCHS "120") + if(CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.9) + list(APPEND NVTE_SPECIFIC_ARCHS "120f") + else() + list(APPEND NVTE_SPECIFIC_ARCHS "120a") + endif() endif() # cuDNN frontend API @@ -123,9 +168,23 @@ find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) # Configure Transformer Engine library include_directories(${PROJECT_SOURCE_DIR}/..) 
set(transformer_engine_SOURCES) +set(transformer_engine_cpp_sources) +set(transformer_engine_cuda_sources) +set(transformer_engine_cuda_arch_specific_sources) + # Source files in both cuda and rocm -list(APPEND transformer_engine_SOURCES +list(APPEND transformer_engine_cpp_sources transformer_engine.cpp + gemm/config.cpp + normalization/common.cpp + normalization/layernorm/ln_api.cpp + normalization/rmsnorm/rmsnorm_api.cpp + util/cuda_driver.cpp + util/cuda_runtime.cpp + util/multi_stream.cpp + util/rtc.cpp) + +list(APPEND transformer_engine_cuda_sources common.cu multi_tensor/adam.cu multi_tensor/compute_scale.cu @@ -138,29 +197,17 @@ list(APPEND transformer_engine_SOURCES transpose/transpose_fusion.cu transpose/multi_cast_transpose.cu transpose/swap_first_dims.cu - activation/gelu.cu dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu - activation/relu.cu - activation/swiglu.cu gemm/cublaslt_gemm.cu - normalization/common.cpp - normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu normalization/layernorm/ln_fwd_cuda_kernel.cu - normalization/rmsnorm/rmsnorm_api.cpp normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu permutation/permutation.cu - util/cast.cu util/padding.cu - util/cuda_driver.cpp - util/cuda_runtime.cpp - util/multi_stream.cpp - util/rtc.cpp - swizzle/swizzle.cu fused_softmax/scaled_masked_softmax.cu fused_softmax/scaled_upper_triang_masked_softmax.cu fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -171,24 +218,91 @@ list(APPEND transformer_engine_SOURCES recipe/current_scaling.cu recipe/delayed_scaling.cu recipe/fp8_block_scaling.cu) + +list(APPEND transformer_engine_cuda_arch_specific_sources + cast/cast.cu + activation/gelu.cu + activation/relu.cu + activation/swiglu.cu) + if(USE_CUDA) -# Removed indent to minimize code diff with NV upstream -# Files unique in cuda building -list(APPEND 
transformer_engine_SOURCES - cudnn_utils.cpp - transpose/quantize_transpose_square_blockwise.cu - transpose/quantize_transpose_vector_blockwise.cu - fused_attn/fused_attn_f16_max512_seqlen.cu - fused_attn/fused_attn_f16_arbitrary_seqlen.cu - fused_attn/fused_attn_fp8.cu - fused_attn/fused_attn.cpp - fused_attn/utils.cu - gemm/cutlass_grouped_gemm.cu - util/cuda_nvml.cpp - comm_gemm_overlap/userbuffers/ipcsocket.cc - comm_gemm_overlap/userbuffers/userbuffers-host.cpp - comm_gemm_overlap/userbuffers/userbuffers.cu - comm_gemm_overlap/comm_gemm_overlap.cpp) +#NV specific source codes + list(APPEND transformer_engine_cpp_sources + cudnn_utils.cpp + fused_attn/fused_attn.cpp + util/cuda_nvml.cpp + comm_gemm_overlap/userbuffers/ipcsocket.cc + comm_gemm_overlap/userbuffers/userbuffers-host.cpp + comm_gemm_overlap/comm_gemm_overlap.cpp) + list(APPEND transformer_engine_cuda_sources + transpose/quantize_transpose_vector_blockwise.cu + fused_attn/fused_attn_f16_max512_seqlen.cu + fused_attn/fused_attn_f16_arbitrary_seqlen.cu + fused_attn/fused_attn_fp8.cu + fused_attn/utils.cu + swizzle/swizzle.cu + swizzle/swizzle_block_scaling.cu + recipe/nvfp4.cu + comm_gemm_overlap/userbuffers/userbuffers.cu) + list(APPEND transformer_engine_cuda_arch_specific_sources + gemm/cutlass_grouped_gemm.cu + transpose/quantize_transpose_square_blockwise.cu + transpose/quantize_transpose_vector_blockwise_fp4.cu + hadamard_transform/hadamard_transform.cu + hadamard_transform/hadamard_transform_cast_fusion.cu) +else() +#ROCm specific source codes + list(APPEND transformer_engine_cpp_sources + fused_attn_rocm/fused_attn.cpp + gemm/rocm_gemm.cu + amd_detail/system.cpp) + list(APPEND transformer_engine_cuda_sources + fused_attn_rocm/fused_attn_aotriton.cpp + fused_attn_rocm/fused_attn_ck.cpp + fused_attn_rocm/utils.cpp) +endif() + + +# Compiling the files with the worst compilation time first to hopefully overlap +# better with the faster-compiling cpp files +list(APPEND transformer_engine_SOURCES 
${transformer_engine_cuda_arch_specific_sources} + ${transformer_engine_cuda_sources} + ${transformer_engine_cpp_sources}) + +if(USE_CUDA) +# Set compile options for CUDA sources with generic architectures +foreach(cuda_source IN LISTS transformer_engine_cuda_sources) + set(arch_compile_options) + foreach(arch IN LISTS NVTE_GENERIC_ARCHS) + list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}") + endforeach() + + if(arch_compile_options) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_OPTIONS ${arch_compile_options} + ) + endif() +endforeach() + +# Set compile options for CUDA sources with specific architectures +foreach(cuda_source IN LISTS transformer_engine_cuda_arch_specific_sources) + set(arch_compile_options) + foreach(arch IN LISTS NVTE_SPECIFIC_ARCHS) + list(APPEND arch_compile_options "--generate-code=arch=compute_${arch},code=sm_${arch}") + endforeach() + + if(arch_compile_options) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_OPTIONS ${arch_compile_options} + ) + endif() +endforeach() if (NVTE_WITH_CUBLASMP) list(APPEND transformer_engine_SOURCES @@ -196,14 +310,8 @@ list(APPEND transformer_engine_SOURCES endif() add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) + else() - list(APPEND transformer_engine_SOURCES - fused_attn_rocm/fused_attn.cpp - fused_attn_rocm/fused_attn_aotriton.cpp - fused_attn_rocm/fused_attn_ck.cpp - fused_attn_rocm/utils.cpp - gemm/rocm_gemm.cu - amd_detail/system.cpp) # process source code files set(TE ${CMAKE_CURRENT_SOURCE_DIR}/../..) 
@@ -235,22 +343,19 @@ else() message(STATUS "nvte hipified sources: ${te_hip_sources}") add_library(transformer_engine SHARED ${te_hip_sources}) - target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}") + target_include_directories(transformer_engine PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) endif() target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") if (USE_CUDA) -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "gemm/cutlass_grouped_gemm.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") -else() - message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") -endif() +# CUTLASS kernels require SM90a and cause hang in debug build +set_property( + SOURCE gemm/cutlass_grouped_gemm.cu + APPEND + PROPERTY + COMPILE_OPTIONS "--generate-code=arch=compute_90a,code=sm_90a;-g0") endif() #USE_CUDA # Configure dependencies @@ -261,7 +366,7 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") @@ -407,37 +512,43 @@ target_include_directories(transformer_engine PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/string_headers") # Compiler options -set_source_files_properties(fused_softmax/scaled_masked_softmax.cu - fused_softmax/scaled_upper_triang_masked_softmax.cu - fused_softmax/scaled_aligned_causal_masked_softmax.cu - multi_tensor/adam.cu - multi_tensor/compute_scale.cu - multi_tensor/l2norm.cu - multi_tensor/scale.cu - multi_tensor/sgd.cu - fused_attn/flash_attn.cu - fused_attn/context_parallel.cu - fused_attn/kv_cache.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") if(USE_CUDA) 
+set(nvte_sources_with_fast_math) +list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu + fused_softmax/scaled_upper_triang_masked_softmax.cu + fused_softmax/scaled_aligned_causal_masked_softmax.cu + multi_tensor/adam.cu + multi_tensor/compute_scale.cu + multi_tensor/l2norm.cu + multi_tensor/scale.cu + multi_tensor/sgd.cu + fused_attn/flash_attn.cu + fused_attn/context_parallel.cu + fused_attn/kv_cache.cu) + option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) - set_source_files_properties(activation/gelu.cu - activation/relu.cu - activation/swiglu.cu - util/cast.cu - PROPERTIES - COMPILE_OPTIONS "--use_fast_math") + list(APPEND nvte_sources_with_fast_math activation/gelu.cu + activation/relu.cu + activation/swiglu.cu) endif() + +foreach(cuda_source IN LISTS nvte_sources_with_fast_math) + set_property( + SOURCE ${cuda_source} + APPEND + PROPERTY + COMPILE_OPTIONS "--use_fast_math") +endforeach() + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr") set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -O3") # Number of parallel build jobs -if(ENV{MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{MAX_JOBS}") -elseif(ENV{NVTE_BUILD_MAX_JOBS}) - set(BUILD_JOBS_STR "$ENV{NVTE_BUILD_MAX_JOBS}") +if($ENV{MAX_JOBS}) + set(BUILD_JOBS_STR $ENV{MAX_JOBS}) +elseif($ENV{NVTE_BUILD_MAX_JOBS}) + set(BUILD_JOBS_STR $ENV{NVTE_BUILD_MAX_JOBS}) else() set(BUILD_JOBS_STR "max") endif() diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 26672bafd..f0335f44c 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # # See LICENSE for license information. @@ -10,23 +10,19 @@ import functools import glob import importlib -from importlib.metadata import version, metadata, PackageNotFoundError -import logging +from importlib.metadata import version, distribution, PackageNotFoundError import os from pathlib import Path import platform import subprocess import sys import sysconfig -from typing import Optional +from typing import Optional, Tuple import transformer_engine -_logger = logging.getLogger(__name__) - - @functools.lru_cache(maxsize=None) -def _is_pip_package_installed(package) -> bool: +def _is_package_installed(package) -> bool: """Check if the given package is installed via pip.""" # This is needed because we only want to return true @@ -34,12 +30,34 @@ def _is_pip_package_installed(package) -> bool: # if it's importable in the current directory due to # the presence of the shared library module. try: - metadata(package) + distribution(package) except PackageNotFoundError: return False return True +@functools.lru_cache(maxsize=None) +def _is_package_installed_from_wheel(package) -> bool: + """Check if the given package is installed via PyPI.""" + + if not _is_package_installed(package): + return False + + te_dist = distribution(package) + te_wheel_file = "" + for file_path in te_dist.files: + if file_path.name == "WHEEL": + te_wheel_file = te_dist.locate_file("") / file_path + if not te_wheel_file: + return False + + with te_wheel_file.open("r") as f: + for line in f: + if line.startswith("Root-Is-Purelib:"): + return line.strip().split(":")[1].strip().lower() == "true" + return False + + @functools.lru_cache(maxsize=None) def _find_shared_object_in_te_dir(te_path: Path, prefix: str) -> Optional[Path]: """ @@ -115,6 +133,21 @@ def _get_shared_object_file(library: str) -> Path: ) +def get_te_core_package_info() -> Tuple[bool, str, str]: + """ + Check if Tranformer Engine core package is installed. + Returns the module name and version if found. 
+ """ + + te_core_packages = ("transformer-engine-cu12", "transformer-engine-cu13") + if te_rocm_build: + te_core_packages = ("transformer-engine-rocm") + for package in te_core_packages: + if _is_package_installed(package): + return True, package, version(package) + return False, "", "" + + @functools.lru_cache(maxsize=None) def load_framework_extension(framework: str) -> None: """ @@ -133,41 +166,30 @@ def load_framework_extension(framework: str) -> None: if framework == "torch": extra_dep_name = "pytorch" - te_cuda_vers = "rocm" if te_rocm_build else "cu12" + # Find the TE packages. The core and framework packages can only be installed via PyPI. + # For the `transformer-engine` package, we need to check explicity. + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() + te_framework_installed = _is_package_installed(module_name) + te_installed = _is_package_installed("transformer_engine") + te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") + + assert te_installed, "Could not find `transformer_engine`." # If the framework extension pip package is installed, it means that TE is installed via # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework - # extension are all installed via PyPI and have matching version. - if _is_pip_package_installed(module_name): - assert _is_pip_package_installed( - "transformer_engine" - ), "Could not find `transformer-engine`." - assert _is_pip_package_installed( - f"transformer_engine_{te_cuda_vers}" - ), f"Could not find `transformer-engine-{te_cuda_vers}`." - assert ( - version(module_name) - == version("transformer-engine") - == version(f"transformer-engine-{te_cuda_vers}") - ), ( - "TransformerEngine package version mismatch. Found" + # extension are all installed via PyPI and have matching versions. + if te_framework_installed: + assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." 
+ assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`." + + assert version(module_name) == version("transformer-engine") == te_core_version, ( + "Transformer Engine package version mismatch. Found" f" {module_name} v{version(module_name)}, transformer-engine" - f" v{version('transformer-engine')}, and transformer-engine-{te_cuda_vers}" - f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'" + f" v{version('transformer-engine')}, and {te_core_package_name}" + f" v{te_core_version}. Install transformer-engine using " + f"'pip3 install --no-build-isolation transformer-engine[{extra_dep_name}]==VERSION'" ) - # If the core package is installed via PyPI, log if - # the framework extension is not found from PyPI. - # Note: Should we error? This is a rare use case. - if _is_pip_package_installed(f"transformer-engine-{te_cuda_vers}"): - if not _is_pip_package_installed(module_name): - _logger.info( - "Could not find package %s. Install transformer-engine using " - f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'", - module_name, - ) - # After all checks are completed, load the shared object file. spec = importlib.util.spec_from_file_location(module_name, _get_shared_object_file(framework)) solib = importlib.util.module_from_spec(spec) @@ -175,6 +197,35 @@ def load_framework_extension(framework: str) -> None: spec.loader.exec_module(solib) +def sanity_checks_for_pypi_installation() -> None: + """Ensure that package is installed correctly if using PyPI.""" + + te_core_installed, te_core_package_name, te_core_version = get_te_core_package_info() + te_installed = _is_package_installed("transformer_engine") + te_installed_via_pypi = _is_package_installed_from_wheel("transformer_engine") + + assert te_installed, "Could not find `transformer-engine`." + + # If the core package is installed via PyPI. 
+ if te_core_installed: + assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package." + assert version("transformer-engine") == te_core_version, ( + "Transformer Engine package version mismatch. Found " + f"transformer-engine v{version('transformer-engine')} " + f"and {te_core_package_name} v{te_core_version}." + ) + + # Only the metapackage is found, invalid usecase. + elif te_installed_via_pypi: + raise RuntimeError( + "Found empty `transformer-engine` meta package installed. " + "Install `transformer-engine` with framework extensions via" + "'pip3 install --no-build-isolation transformer-engine[pytorch,jax]==VERSION'" + " or 'pip3 install transformer-engine[core]` for the TE core lib only. The `core_cu12`" + " or `core_cu13` extra deps can be used to specify CUDA version for the TE core lib." + ) + + @functools.lru_cache(maxsize=None) def _get_sys_extension() -> str: """File extension for shared objects.""" @@ -257,9 +308,7 @@ def _load_cudnn(): return handle # Attempt to locate libcudnn via ldconfig - libs = subprocess.check_output( - f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True - ) + libs = subprocess.check_output(["ldconfig", "-p"]) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: @@ -289,9 +338,7 @@ def _load_nvrtc(): return handle # Attempt to locate NVRTC via ldconfig - libs = subprocess.check_output( - f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True - ) + libs = subprocess.check_output(["ldconfig", "-p"]) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: @@ -321,9 +368,7 @@ def _load_curand(): return handle # Attempt to locate cuRAND via ldconfig - libs = subprocess.check_output( - f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True - ) + libs = subprocess.check_output(["ldconfig", "-p"]) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: @@ -358,9 +403,6 @@ def _load_core_library(): _CUBLAS_LIB_CTYPES = 
_load_nvidia_cuda_library("cublas") _CUDART_LIB_CTYPES = _load_nvidia_cuda_library("cuda_runtime") - # Needed to find the correct headers for NVRTC kernels. - if not os.getenv("NVTE_CUDA_INCLUDE_DIR") and _nvidia_cudart_include_dir(): - os.environ["NVTE_CUDA_INCLUDE_DIR"] = _nvidia_cudart_include_dir() except (OSError, subprocess.CalledProcessError): pass finally: diff --git a/transformer_engine/common/activation/activation_template.h b/transformer_engine/common/activation/activation_template.h index 67f173a4a..7353c3e1d 100644 --- a/transformer_engine/common/activation/activation_template.h +++ b/transformer_engine/common/activation/activation_template.h @@ -14,26 +14,17 @@ #include #include +#include "../cast/dispatch/gated.cuh" +#include "../cast/dispatch/quantize.cuh" #include "../common.h" -#include "../util/cast_gated_kernels.cuh" -#include "../util/cast_kernels.cuh" -#include "../util/math.h" -#include "../util/vectorized_pointwise.h" namespace transformer_engine { template void act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DBIAS = false; - constexpr bool IS_DACT = false; constexpr bool IS_ACT = true; - constexpr NVTETensor dbias = nullptr; - constexpr NVTETensor workspace = nullptr; - constexpr const NVTETensor grad = nullptr; - - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_fwd_helper(input, output, nullptr, stream); } template @@ -42,31 +33,25 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, using namespace detail; constexpr bool IS_DBIAS = false; constexpr bool IS_DACT = true; - constexpr bool IS_ACT = false; constexpr NVTETensor dbias = nullptr; constexpr NVTETensor workspace = nullptr; - quantize_helper(input, grad, output, dbias, workspace, - nullptr, stream); + dispatch::quantize_bwd_helper(grad, input, output, dbias, workspace, + nullptr, stream); } template -void gated_act_fn(const NVTETensor 
input, NVTETensor output, cudaStream_t stream) { +void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DGATED = false; - constexpr NVTETensor grad = nullptr; - - quantize_gated_helper(grad, input, output, stream); + dispatch::quantize_gated_fwd_helper(input, output, p, stream); } template -void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, +void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) { using namespace detail; - constexpr bool IS_DGATED = true; - - quantize_gated_helper(grad, input, output, stream); + dispatch::quantize_gated_bwd_helper(grad, input, output, p, stream); } } // namespace transformer_engine diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index 0cf43007a..4979023ef 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -20,17 +20,32 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_geglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { 
NVTE_API_CALL(nvte_dgeglu); using namespace transformer_engine; - dgated_act_fn, dgelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dgelu>(grad, input, output, e, stream); } void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -46,15 +61,30 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dqgelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_qgeglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dqgeglu); using namespace transformer_engine; - dgated_act_fn, dqgelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dqgelu>(grad, input, output, e, stream); } diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index a794b7315..c0ef9fd65 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -20,17 +20,32 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + 
NVTE_API_CALL(nvte_quantize_dbias_drelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_reglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dreglu); using namespace transformer_engine; - dgated_act_fn, drelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, drelu>(grad, input, output, e, stream); } void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { @@ -46,15 +61,30 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsrelu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_sreglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dsreglu); using namespace transformer_engine; - dgated_act_fn, dsrelu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dsrelu>(grad, 
input, output, e, stream); } diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 819496474..6957a91e6 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -20,15 +20,47 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output dact_fn>(grad, input, output, stream); } +void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activation_input, + NVTETensor output, NVTETensor dbias, NVTETensor workspace, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias_dsilu); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = true; + + dispatch::quantize_bwd_helper>( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swiglu); using namespace transformer_engine; - gated_act_fn>(input, output, stream); + Empty e = {}; + gated_act_fn>(input, output, e, stream); } void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_dswiglu); using namespace transformer_engine; - dgated_act_fn, dsilu>(grad, input, output, stream); + Empty e = {}; + dgated_act_fn, dsilu>(grad, input, output, e, stream); +} + +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream) { + NVTE_API_CALL(nvte_clamped_swiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha}; + gated_act_fn>(input, output, param, stream); +} + +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream) { + NVTE_API_CALL(nvte_clamped_dswiglu); + using namespace transformer_engine; + ClampedSwiGLUParam param = {limit, alpha}; + dgated_act_fn, 
clamped_dsilu>( + grad, input, output, param, stream); } diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu new file mode 100644 index 000000000..7ecc05d2e --- /dev/null +++ b/transformer_engine/common/cast/cast.cu @@ -0,0 +1,106 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif //#ifndef __HIP_PLATFORM_AMD__ +#include +#include +#include + +#include "../common.h" +#include "../transpose/cast_transpose.h" +#include "../util/multi_stream.h" +#include "../utils.cuh" +#include "dispatch/dequantize.cuh" +#include "dispatch/quantize.cuh" +#include "transformer_engine/transpose.h" + +void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_fwd_helper(input, output, nullptr, stream); +} + +void nvte_quantize_noop(const NVTETensor input, NVTETensor output, NVTETensor noop, + cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_noop); + using namespace transformer_engine; + + // Create config with noop tensor + QuantizationConfig quant_config; + quant_config.noop_tensor = noop; + + nvte_quantize_v2(input, output, reinterpret_cast(&quant_config), stream); +} + +void nvte_quantize_v2(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_v2); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + dispatch::quantize_fwd_helper(input, output, quant_config, stream); 
+} + +void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias, + NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_quantize_dbias); + using namespace transformer_engine; + + constexpr bool IS_DBIAS = true; + constexpr bool IS_DACT = false; + constexpr const NVTETensor activation_input = nullptr; + + dispatch::quantize_bwd_helper( + input, activation_input, output, dbias, workspace, nullptr, stream); +} + +void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_dequantize); + using namespace transformer_engine; + dispatch::dequantize_helper(*convertNVTETensorCheck(input), convertNVTETensorCheck(output), + stream); +} + +void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, + const NVTEQuantizationConfig quant_configs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_quantize); + using namespace transformer_engine; + + constexpr bool IS_ACT = false; + + const size_t num_streams = nvte_get_num_compute_streams(); + + int num_stream_used = std::min(num_streams, num_tensors); + // wait for current stream to finish + NVTE_CHECK_CUDA(cudaEventRecord(detail::get_compute_stream_event(0), stream)); + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaStreamWaitEvent(detail::get_compute_stream(s), detail::get_compute_stream_event(0))); + } + + for (int i = 0; i < num_tensors; i++) { + dispatch::quantize_fwd_helper( + inputs[i], outputs[i], quant_configs, detail::get_compute_stream(i % num_streams)); + } + + // record events on compute streams + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA( + cudaEventRecord(detail::get_compute_stream_event(s), detail::get_compute_stream(s))); + } + // wait for all compute streams to finish + for (int s = 0; s < num_stream_used; s++) { + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream, detail::get_compute_stream_event(s))); + } +} diff --git 
/*************************************************************************
 * This file was modified for portability to AMDGPU
 * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file common.cuh
 * \brief Common functions in quantize.
 *
 * NOTE(review): template parameter lists and #include targets in this file
 * were stripped by text extraction (bare "template" / "#include" below) —
 * restore them from the upstream sources before compiling.
 */

#ifndef TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_

#include
#ifndef __HIP_PLATFORM_AMD__
#include
#endif  //#ifndef __HIP_PLATFORM_AMD__
#include
#include

#include "../../common.h"
#include "../../utils.cuh"

namespace transformer_engine {
namespace dispatch {
namespace common {
// True when the flattened 1D length of `t` divides evenly into blocks of
// `elems_per_block`, i.e. a kernel with that block size has no ragged tail.
inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
  const size_t N = product(t->data.shape);
  const bool isFullTile = (N % elems_per_block == 0);
  return isFullTile;
}

// TMA transfers move 16-byte units, so the innermost dimension must contain a
// whole number of such units for the tensor's element width.
inline bool dimensions_supported_by_TMA(const Tensor *const t) {
  const size_t cols = t->flat_last_dim();
  constexpr size_t TMA_bytes = 16;
  const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
  return cols % alignment_requirement == 0;
}

namespace kernel {

constexpr size_t THREADS_PER_BLOCK = 256;
// Column-wise reduction of a partial-dbias workspace: each thread owns `nvec`
// consecutive columns, sums them over all `rows` in fp32, then emits a single
// vectorized store in the output type.
template
__global__ void __launch_bounds__(THREADS_PER_BLOCK)
    reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
                        const size_t rows, const size_t cols) {
  using ComputeVec = Vec;  // NOTE(review): vector type args stripped by extraction
  using OutputVec = Vec;   // NOTE(review): vector type args stripped by extraction

  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;

  // Threads whose column group lies past the end have nothing to do.
  if (thread_id * nvec >= cols) {
    return;
  }

  const float *const thread_in_base = dbias_partial + thread_id * nvec;
  OType *const thread_out_base = dbias_output + thread_id * nvec;

  ComputeVec ldg_vec;
  ComputeVec acc_vec;
  acc_vec.clear();
  // Accumulate this thread's nvec columns across every partial row of the
  // workspace (row stride is `cols` floats).
  for (int i = 0; i < rows; ++i) {
    ldg_vec.load_from(thread_in_base + i * cols);
#pragma unroll
    for (int e = 0; e < nvec; ++e) {
      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
    }
  }

  // Downcast the fp32 accumulators to the output type and store vectorized.
  OutputVec stg_vec;
#pragma unroll
  for (int e = 0; e < nvec; ++e) {
    stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]);
  }
  stg_vec.store_to(thread_out_base);
}
}  // namespace kernel

// Host launcher: reduce the (rows x cols) fp32 partial-dbias workspace into
// the `dbias` tensor. The vector width is chosen so that each store is
// 8 bytes (stg.64), so `cols` must be divisible by it.
template
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
                  cudaStream_t stream) {
  using namespace kernel;
  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);

  NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
  const size_t reduce_dbias_num_blocks = DIVUP(cols, THREADS_PER_BLOCK * reduce_dbias_nvec);

  // NOTE(review): the kernel's template arguments and launch configuration
  // were stripped by extraction — presumably
  // <<<reduce_dbias_num_blocks, THREADS_PER_BLOCK, 0, stream>>>; confirm
  // against upstream.
  reduce_dbias_kernel
      <<>>(
          reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace common
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_QUANTIZE_CORE_COMMON_CUH_

/*************************************************************************
 * This file was modified for portability to AMDGPU
 * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
+ ************************************************************************/ + +/*! \file dequantize.cuh + * \brief Dequantize dispatcher. + */ + +#ifndef TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ +#define TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ + +#include + +#include "../../common.h" +#include "../fp8/dequantize_fp8.cuh" +#include "../mxfp8/dequantize_mxfp8.cuh" +#ifndef __HIP_PLATFORM_AMD__ +#include "../nvfp4/dequantize_nvfp4.cuh" +#endif //#ifndef __HIP_PLATFORM_AMD__ + +namespace transformer_engine { +namespace dispatch { + +inline void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + switch (input.scaling_mode) { + case NVTE_DELAYED_TENSOR_SCALING: { + NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); + NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + fp8::dequantize(input, output, stream); + break; + } + case NVTE_MXFP8_1D_SCALING: { +#ifndef __HIP_PLATFORM_AMD__ + if (is_supported_by_CC_100()) { +#endif //#ifndef __HIP_PLATFORM_AMD__ + mxfp8::dequantize(input, output, stream); +#ifndef __HIP_PLATFORM_AMD__ + } else { + NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); + } +#endif //#ifndef __HIP_PLATFORM_AMD__ + break; + } +#ifndef __HIP_PLATFORM_AMD__ + case NVTE_NVFP4_1D_SCALING: { + nvfp4::dequantize(input, output, stream); + break; + } +#endif //#ifndef __HIP_PLATFORM_AMD__ + default: + NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); + } +} + +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DISPATCH_DEQUANTIZE_CUH_ diff --git a/transformer_engine/common/cast/dispatch/gated.cuh b/transformer_engine/common/cast/dispatch/gated.cuh new file mode 100644 index 
/*************************************************************************
 * This file was modified for portability to AMDGPU
 * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file gated.cuh
 * \brief Gated dispatcher.
 *
 * NOTE(review): template parameter lists (activation functors / ParamOP) and
 * #include targets were stripped by text extraction — restore them from the
 * upstream sources before compiling.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_

#include

#include "../../common.h"
#include "../../utils.cuh"
#include "../fp8/gated_fp8.cuh"
#include "../mxfp8/gated_mxfp8.cuh"

namespace transformer_engine {
namespace dispatch {

// Forward gated-activation quantize: the input interleaves activation and
// gate halves along the last dimension ([*, 2*cols]); the output holds the
// activated product ([*, cols]). Dispatches on the OUTPUT scaling mode.
template
void quantize_gated_fwd_helper(const NVTETensor nvte_input, NVTETensor nvte_output, ParamOP &p,
                               cudaStream_t stream) {
  const Tensor input = *convertNVTETensorCheck(nvte_input);
  Tensor *output = convertNVTETensorCheck(nvte_output);

  CheckInputTensor(input, "input");
  CheckOutputTensor(*output, "output", /*allow_empty=*/false);

  // Last dim holds [activation | gate], hence the divide by two.
  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim() / 2;

  NVTE_CHECK(input.flat_last_dim() % 2 == 0,
             "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
             input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
  NVTE_CHECK(output->flat_last_dim() == cols,
             "Wrong output shape. Expected (after flattening) [*, ", cols, "], got [",
             output->flat_first_dim(), ", ", output->flat_last_dim(), "].");

  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  switch (output->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
#ifdef __HIP_PLATFORM_AMD__
      // AMD path: always the non-TMA kernel.
      fp8::cast_gated_fwd(input, output, p, stream);
#else
      // TMA kernels need 32-element column alignment and CC 10.0+.
      const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
      if (use_tma_kernels) {
        Tensor dummy_grad_tensor;  // fwd pass has no gradient input
        fp8::cast_gated_tma(input, dummy_grad_tensor,
                            output, p, stream);
      } else {
        fp8::cast_gated_fwd(input, output, p, stream);
      }
#endif  //#ifdef __HIP_PLATFORM_AMD__
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      NVTE_CHECK(cols % 32 == 0,
                 "Invalid input shape. Expected the last dimension to be "
                 "divisible by 32, but got ",
                 cols, ".");
      if (output->has_data()) {
        NVTE_CHECK(is_fp8_dtype(output->data.dtype),
                   "The type of the output tensor should be FP8.");
      }
      if (output->has_columnwise_data()) {
        NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
                   "The type of the columnwise output tensor should be FP8.");
      }
#ifdef __HIP_PLATFORM_AMD__
      //TODO: add gfx950 equivalent checking
#else
      NVTE_CHECK(is_supported_by_CC_100(),
                 "Gated FWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
#endif
      Tensor dummy_grad_tensor;
      mxfp8::quantize_gated(input, dummy_grad_tensor,
                            output, p, stream);
      break;
    }
    default:
      NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
  }
}

// Backward gated-activation quantize: combines the incoming gradient `grad`
// ([*, cols]) with the saved gated input ([*, 2*cols]) and produces a
// quantized gradient with the gated input's shape ([*, 2*cols]).
template
void quantize_gated_bwd_helper(const NVTETensor nvte_grad, const NVTETensor nvte_gated_input,
                               NVTETensor nvte_output, ParamOP &p, cudaStream_t stream) {
  const Tensor &grad = *(convertNVTETensorCheck(nvte_grad));
  const Tensor gated_input = *convertNVTETensorCheck(nvte_gated_input);
  Tensor *output = convertNVTETensorCheck(nvte_output);

  CheckInputTensor(grad, "grad");
  CheckInputTensor(gated_input, "gated_input");
  CheckOutputTensor(*output, "output", /*allow_empty=*/false);

  NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even, but got ",
             gated_input.flat_last_dim(), ".");

  const size_t rows = gated_input.flat_first_dim();
  const size_t cols = gated_input.flat_last_dim() / 2;

  // The gradient arrives in high precision and must match the saved input's
  // dtype and the per-half geometry.
  NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision.");
  NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match.");

  NVTE_CHECK(grad.flat_first_dim() == rows,
             "Wrong Grad shape. Expected first dimension (after flattening) [", rows, ", *], got [",
             grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");
  NVTE_CHECK(grad.flat_last_dim() == cols,
             "Wrong Grad shape. Expected last dimension (after flattening) [", cols, ", *], got [",
             grad.flat_first_dim(), ", ", grad.flat_last_dim(), "].");

  NVTE_CHECK(output->has_data() || output->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  NVTE_CHECK(output->flat_first_dim() == rows, "Wrong output shape. Expected (after flattening) [",
             rows, ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(output->flat_last_dim() == cols * 2,
             "Wrong output shape. Expected (after flattening) [*, ", cols * 2, "], got [",
             output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
  NVTE_CHECK(gated_input.data.shape == output->data.shape,
             "Gated input and output shapes must match. Input shape: ", gated_input.data.shape,
             ", output shape: ", output->data.shape, ".");

  switch (output->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
#ifdef __HIP_PLATFORM_AMD__
      // AMD path: always the non-TMA kernel.
      fp8::cast_gated_bwd(gated_input, grad, output, p, stream);
#else
      // TMA kernels need 32-element column alignment and CC 10.0+.
      const bool use_tma_kernels = (cols % 32 == 0) && is_supported_by_CC_100();
      if (use_tma_kernels) {
        fp8::cast_gated_tma(gated_input, grad, output, p,
                            stream);
      } else {
        fp8::cast_gated_bwd(gated_input, grad, output, p, stream);
      }
#endif  //#ifdef __HIP_PLATFORM_AMD__
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      NVTE_CHECK(cols % 32 == 0,
                 "Invalid input shape. Expected the last dimension to be "
                 "divisible by 32, but got ",
                 cols, ".");
      if (output->has_data()) {
        NVTE_CHECK(is_fp8_dtype(output->data.dtype),
                   "The type of the output tensor should be FP8.");
      }
      if (output->has_columnwise_data()) {
        NVTE_CHECK(is_fp8_dtype(output->columnwise_data.dtype),
                   "The type of the columnwise output tensor should be FP8.");
      }
#ifdef __HIP_PLATFORM_AMD__
      // add gfx950 equivalent check
#else
      NVTE_CHECK(is_supported_by_CC_100(),
                 "Gated BWD NVTE_MXFP8_1D_SCALING is only supported on SM 10.0+");
#endif  //#ifdef __HIP_PLATFORM_AMD__

      mxfp8::quantize_gated(gated_input, grad, output, p,
                            stream);
      break;
    }
    default:
      NVTE_ERROR("Not supported scaling mode: " + to_string(output->scaling_mode) + ".");
  }
}
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_GATED_CUH_

/*************************************************************************
 * This file was modified for portability to AMDGPU
 * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file quantize.cuh
 * \brief Quantize dispatcher.
 *
 * NOTE(review): template parameter lists and #include targets were stripped
 * by text extraction (bare "template" / "#include" and empty argument lists
 * at call sites below) — restore them from the upstream sources before
 * compiling.
 */

#ifndef TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_
#define TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_

#include

#include "../../common.h"
#include "../../transpose/cast_transpose.h"
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#ifndef __HIP_PLATFORM_AMD__
#include "../nvfp4/quantize_nvfp4.cuh"
#include "../nvfp4/quantize_transpose_nvfp4.cuh"
#endif  //#ifndef __HIP_PLATFORM_AMD__

namespace transformer_engine {
namespace dispatch {

// Forward-direction quantize dispatcher. Parses the optional quantization
// config (noop tensor, stochastic rounding, blockwise options) and routes to
// the kernel family selected by the OUTPUT tensor's scaling mode.
template
void quantize_fwd_helper(const NVTETensor input, NVTETensor output,
                         const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;

  const Tensor *input_tensor = convertNVTETensorCheck(input);
  Tensor *output_tensor = convertNVTETensorCheck(output);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast(quant_config);
  }

  // Noop flag: when absent, a dummy tensor stands in so kernels can always
  // dereference it.
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }

  NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      // Forward direction has no dbias/workspace; null placeholders satisfy
      // the shared fused-kernel signatures.
      const Tensor *dummy_input_tensor = nullptr;
      Tensor *dummy_dbias_tensor = nullptr;
      Tensor *dummy_workspace_tensor = nullptr;
      if (output_tensor->has_columnwise_data()) {
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_ACT) {
          cast_transpose(*input_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused(
              *input_tensor, dummy_input_tensor, output_tensor, dummy_dbias_tensor,
              dummy_workspace_tensor, stream);
        }
      } else if (output_tensor->has_data()) {
        fp8::quantize(
            *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
            dummy_workspace_tensor, stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      const Tensor *dummy_input_tensor = nullptr;
      Tensor *dummy_dbias_tensor = nullptr;
      Tensor *dummy_workspace_tensor = nullptr;
      mxfp8::quantize(
          *input_tensor, dummy_input_tensor, noop_tensor, output_tensor, dummy_dbias_tensor,
          dummy_workspace_tensor, stream);
      break;
    }
#ifndef __HIP_PLATFORM_AMD__
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK(!IS_ACT, "IS_ACT is not supported by FWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*input_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel: the optimized fused quantize+transpose path needs
      // bf16 input, 32-aligned dims, and rowwise output data.
      int32_t rows = input_tensor->flat_first_dim();
      int32_t cols = input_tensor->flat_last_dim();
      auto dtype = input_tensor->dtype();
      bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                                  (cols % 32 == 0) && output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4::quantize_transpose(
              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        } else {
          nvfp4::quantize_transpose(
              *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        }
      } else {
        // Generic fallback: amax comes from whichever direction is allocated.
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/input_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_ACT, ParamOP, OP parameters support.
      NVTE_CHECK(!IS_ACT, "IS_ACT is not implemented for FWD NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      // Per-direction scale layout: COMPACT vs GEMM_READY from the config.
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                Float8BlockScaleTensorFormat::COMPACT);
        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                   Float8BlockScaleTensorFormat::COMPACT);
        columnwise_option = columnwise_compact
                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
      }
      quantize_transpose_vector_blockwise(
          input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
#endif//#ifndef __HIP_PLATFORM_AMD__
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
}

// Backward-direction quantize dispatcher: quantizes `grad`, optionally fusing
// a dbias reduction (IS_DBIAS, using `workspace`) and/or an activation
// gradient (IS_DACT, reading `input`). Routing mirrors the forward helper.
template
void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                         NVTETensor dbias, NVTETensor workspace,
                         const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
  using namespace detail;

  const Tensor *grad_tensor = convertNVTETensorCheck(grad);
  const Tensor *input_tensor = convertNVTETensor(input);

  Tensor *output_tensor = convertNVTETensorCheck(output);
  Tensor *dbias_tensor = convertNVTETensor(dbias);
  Tensor *workspace_tensor = convertNVTETensor(workspace);

  // Quantization config
  QuantizationConfig quant_config_cpp;
  if (quant_config != nullptr) {
    quant_config_cpp = *reinterpret_cast(quant_config);
  }

  // Noop flag
  Tensor dummy_tensor;
  Tensor *noop_tensor = &dummy_tensor;
  if (quant_config_cpp.noop_tensor != nullptr) {
    noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
  }

  // Check for unsupported options
  if (quant_config_cpp.stochastic_rounding) {
    NVTE_CHECK(output_tensor->scaling_mode == NVTE_NVFP4_1D_SCALING,
               "Stochastic rounding is only supported for NVFP4 quantization.");
  }

  NVTE_CHECK(output_tensor->has_data() || output_tensor->has_columnwise_data(),
             "Either rowwise or columnwise output data need to be allocated.");

  // Dispatch to quantization kernel depending on data format
  switch (output_tensor->scaling_mode) {
    case NVTE_DELAYED_TENSOR_SCALING: {
      if (output_tensor->has_columnwise_data()) {
        NVTE_CHECK(output_tensor->has_data(),
                   "Quantizing in only the columnwise direction not supported yet!");
        if constexpr (!IS_DBIAS && !IS_DACT) {
          cast_transpose(*grad_tensor, *noop_tensor, output_tensor, stream);
        } else {
          cast_transpose_fused(
              *grad_tensor, input_tensor, output_tensor, dbias_tensor, workspace_tensor, stream);
        }
      } else if (output_tensor->has_data()) {
        fp8::quantize(
            *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
            stream);
      }
      break;
    }
    case NVTE_MXFP8_1D_SCALING: {
      mxfp8::quantize(
          *grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
          stream);
      break;
    }
#ifndef __HIP_PLATFORM_AMD__
    case NVTE_NVFP4_1D_SCALING: {
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not supported by BWD NVTE_NVFP4_1D_SCALING");

      // Check tensors
      CheckNoopTensor(*noop_tensor, "cast_noop");
      CheckInputTensor(*grad_tensor, "input");
      CheckOutputTensor(*output_tensor, "output", false);

      // Choose kernel: same criteria as the forward NVFP4 path, applied to
      // the gradient tensor.
      int32_t rows = grad_tensor->flat_first_dim();
      int32_t cols = grad_tensor->flat_last_dim();
      auto dtype = grad_tensor->dtype();
      bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) &&
                                  (cols % 32 == 0) && output_tensor->has_data();

      // Launch NVFP4 quantize kernel
      if (use_optimized_kernel) {
        if (quant_config_cpp.nvfp4_2d_quantization) {
          nvfp4::quantize_transpose(
              *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        } else {
          nvfp4::quantize_transpose(
              *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream);
        }
      } else {
        auto &global_amax = (output_tensor->amax.dptr != nullptr) ? output_tensor->amax
                                                                  : output_tensor->columnwise_amax;
        quantize_transpose_vector_blockwise_fp4(
            /*input=*/grad_tensor->data, /*global_amax=*/global_amax,
            /*scale_inv=*/output_tensor->scale_inv,
            /*scale_inv_t=*/output_tensor->columnwise_scale_inv,
            /*output=*/output_tensor->data, /*output_t=*/output_tensor->columnwise_data,
            /*epsilon=*/0.0f, /*return_identity=*/output_tensor->has_data(),
            /*return_transpose=*/output_tensor->has_columnwise_data(), /*pow2_scale=*/false,
            /*swizzled_scale=*/false,
            /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding,
            /*rng_state=*/quant_config_cpp.rng_state,
            /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization,
            /*noop_tensor=*/noop_tensor->data, /*stream=*/stream);
      }
      break;
    }
    case NVTE_BLOCK_SCALING_2D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_2D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      quantize_transpose_square_blockwise(
          grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon,
          /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
          /*noop_tensor=*/noop_tensor->data, stream);
      break;
    }
    case NVTE_BLOCK_SCALING_1D: {
      // TODO(kwyss): IS_BIAS, IS_DACT, ParamOP, OP parameters support.
      NVTE_CHECK((!IS_DBIAS && !IS_DACT),
                 "IS_DBIAS and IS_DACT are not implemented for BWD NVTE_BLOCK_SCALING_1D");
      bool force_pow_2_scales = quant_config_cpp.force_pow_2_scales;
      float epsilon = quant_config_cpp.amax_epsilon;
      FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE;
      FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE;
      if (output_tensor->has_data()) {
        bool rowwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                Float8BlockScaleTensorFormat::COMPACT);
        rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT
                                         : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY;
      }
      if (output_tensor->has_columnwise_data()) {
        bool columnwise_compact = (quant_config_cpp.float8_block_scale_tensor_format ==
                                   Float8BlockScaleTensorFormat::COMPACT);
        columnwise_option = columnwise_compact
                                ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
                                : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
      }
      quantize_transpose_vector_blockwise(
          grad_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
          output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
          columnwise_option, force_pow_2_scales, noop_tensor->data, stream);
      break;
    }
#endif  //#ifndef __HIP_PLATFORM_AMD__
    default:
      NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + ".");
  }
}

}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_DISPATCH_QUANTIZE_CUH_

/*************************************************************************
 * This file was modified for portability to AMDGPU
 * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file dequantize_fp8.cuh
 * \brief CUDA kernels to dequantize from FP8.
+ */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ + +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif // #ifndef __HIP_PLATFORM_AMD__ +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +struct DequantizeParam { + const float *scale_inv; +}; + +__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { + return value * (*(param.scale_inv)); +} + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + constexpr int nvec = 32 / sizeof(OType); + DequantizeParam p; p.scale_inv = reinterpret_cast(input.scale_inv.dptr); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), nullptr, + reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, + stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh new file mode 100644 index 000000000..aa46a574c --- /dev/null +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -0,0 +1,400 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! 
\file gated_fp8.cuh + * \brief CUDA kernels to cast to FP8 with gated activations. + */ + +#ifndef TRANSFORMER_ENGINE_GATED_FP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_FP8_CUH_ + +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif //#ifndef __HIP_PLATFORM_AMD__ +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +#ifndef __HIP_PLATFORM_AMD__ +namespace kernel { + +constexpr size_t CHUNK_DIM_Y = 128; +constexpr size_t CHUNK_DIM_X = 128; +constexpr size_t THREADS_PER_CHUNK = 512; +constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; +constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 +constexpr size_t BUFFERS_NUM = 2; +constexpr size_t BUFFER_DIM_Y = 32; +constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 +constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 +constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 + +constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 +constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 +static_assert(ITERATIONS >= 1); + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act, + const __grid_constant__ CUtensorMap tensor_map_output_gate, + float *const amax_ptr, float *const scale_inv_ptr, + const float *const scale_ptr, const size_t rows, const size_t cols, + const ParamOP p) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y 
= threadIdx.x / THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; + constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t grad_mem = IS_BWD ? 
buff_size_aligned_in : 0; + + constexpr size_t in_act_mem = buff_size_aligned_in; + constexpr size_t in_gate_mem = buff_size_aligned_in; + constexpr size_t in_mem = in_act_mem + in_gate_mem; + + constexpr size_t out_act_mem = buff_size_aligned_out; + constexpr size_t in_transaction_size = buff_elems * sizeof(IType); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_grad_sh = reinterpret_cast(dshmem); + IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); + IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); + OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); + OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); + + const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); + const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); + const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); + const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); + const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + // Prefetch data of the first stage + + if constexpr (IS_BWD) { + copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, + TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, + chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } else { + copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, + TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], + is_master_thread); + } + +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + const size_t buff = it % BUFFERS_NUM; + const size_t next_it = it + 1; + if (next_it < ITERATIONS) { + const size_t next_buff = next_it % BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_BWD) { + copy_2d_to_sharedx3( + &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, + &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, + &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, + in_transaction_size, &mbar[next_it], is_master_thread); + } else { + copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, + chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, + chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, + &mbar[next_it], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[it], parity); + + IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; + IType *in_act_sh_curr = in_act_sh + buff * buff_elems; + IType *in_gate_sh_curr = in_gate_sh + buff * 
buff_elems; + OType *out_act_sh_curr = out_act_sh + buff * buff_elems; + OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; +#pragma unroll + for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + + float act_elt = static_cast(in_act_sh_curr[shmem_idx]); + float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1; + } + + if constexpr (IS_BWD) { + float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); + + const float x = act_elt; + float act_x; + float dact_x; + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); + act_x = x * s; + if (act_elt <= p.limit) { + dact_x = s + s * (1 - s) * p.alpha * x; + } else { + dact_x = 0.0f; + } + } else { + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } + } + float after_dact = dact_x * grad_elt * gate_elt; + float after_dgate = dgate_elt ? 
act_x * grad_elt : 0.0f; + + out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); + out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); + + amax = fmaxf(amax, fabsf(after_dact)); + amax = fmaxf(amax, fabsf(after_dgate)); + } else { + const float after_act = ActOP(act_elt, p) * gate_elt; + out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); + amax = fmaxf(amax, fabsf(after_act)); + } + } + + // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + + // dGeLU + ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, + chunk_it_offset_y, + reinterpret_cast(out_act_sh_curr)); + + if constexpr (IS_BWD) { + // dGate + ptx::cp_async_bulk_tensor_2d_shared_to_global( + TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, + reinterpret_cast(out_gate_sh_curr)); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + // Destroy the barriers. This invalidates the memory region of the barrier. 
+ // If further computations were to take place in the kernel, this allows the + // memory location of the shared memory barrier to be reused. + if (is_master_thread) { +#pragma unroll + for (int it = 0; it < ITERATIONS; ++it) { + ptx::mbarrier_invalid(&mbar[it]); + } + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace kernel + +template +void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, + cudaStream_t stream) { + using namespace kernel; + checkCuDriverContext(stream); + + NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_BWD ? 2 : 1) * cols; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block_dim(THREADS_PER_CHUNK); + const dim3 grid_dim(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act{}; + alignas(64) CUtensorMap tensor_map_output_gate{}; + + if constexpr (IS_BWD) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, + cols, 0, typeToNumBits(gated_input.dtype())); + } + + const uint32_t tensor_stride_elems = output_cols; + + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); + 
create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); + create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); + create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, + SHMEM_DIM_X, tensor_stride_elems, cols, + typeToNumBits(output->dtype())); + + const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = buff_size_aligned_out; + + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; + + auto kernel = cast_fp8_gated_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, + tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) +} +#endif //#ifndef __HIP_PLATFORM_AMD__ + +template +void cast_gated_fwd(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + GatedActivationKernelLauncher( + reinterpret_cast(input.data.dptr), + 
reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), + output->flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} + +template +void cast_gated_bwd(const Tensor &input, const Tensor &grad, Tensor *output, ParamOP &p, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->dtype(), OType, + + constexpr int nvec = 32 / sizeof(IType); + DGatedActivationKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input.data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), + grad.flat_last_dim(), p, stream);); // NOLINT(*) + ); // NOLINT(*) +} +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GATED_FP8_CUH_ diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh new file mode 100644 index 000000000..9de093e96 --- /dev/null +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -0,0 +1,739 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_fp8.cuh + * \brief CUDA kernels to quantize to FP8. 
+ */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ + +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif //#ifndef __HIP_PLATFORM_AMD__ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../transpose/cast_transpose.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../util/vectorized_pointwise.h" +#include "../../utils.cuh" +#include "../core/common.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace fp8 { +#ifdef __HIP_PLATFORM_AMD__ +constexpr size_t TILE_DIM = 32; +template +__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { + __shared__ float tile[TILE_DIM][TILE_DIM]; + + int tile_start_col = blockIdx.x * TILE_DIM; + int tile_start_row = blockIdx.y * TILE_DIM; + int thread_col_in_tile = threadIdx.x; + int thread_row_in_tile = threadIdx.y; + + int global_col = tile_start_col + thread_col_in_tile; + int global_row = tile_start_row + thread_row_in_tile; + + if (global_row < rows && global_col < cols) { + tile[thread_row_in_tile][thread_col_in_tile] = static_cast(input[global_row * cols + global_col]); + } else { + tile[thread_row_in_tile][thread_col_in_tile] = 0.0f; + } + __syncthreads(); + + for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) { + if (thread_row_in_tile < stride) { + tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; + } + __syncthreads(); + } + + if (thread_row_in_tile == 0 && global_col < cols) { + partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile]; + } +} + +template +void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, + const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) { + dim3 block_dim_partial(TILE_DIM, TILE_DIM); + dim3 grid_dim_partial(DIVUP(cols, TILE_DIM), 
DIVUP(rows, TILE_DIM)); + + const size_t partial_rows = grid_dim_partial.y; + float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); + + partial_reduce_kernel<<>>( + workspace_ptr, + partial_workspace, + rows, cols); + + common::reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); +} + + +#else +namespace quantize_2D_kernel { + +constexpr size_t FP8_CHUNK_DIM_Y = 128; +constexpr size_t FP8_CHUNK_DIM_X = 128; +constexpr size_t FP8_THREADS_PER_CHUNK = 128; +constexpr size_t FP8_BUFFERS_NUM = 2; +constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; +static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); + +constexpr size_t FP8_BUFFER_DIM_Y = 16; +constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 +constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 + +constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 +constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 +static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); + +template +__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) + cast_fp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output, + float *const dbias_workspace, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, + const size_t cols) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + + const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; + + const size_t dbias_offset_Y = blockIdx.y + tid_Y; + const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + 
thread_offset_X; + const bool col_out_of_bounds = my_column >= cols; + const size_t dbias_stride = cols; + + float partial_dbias = 0.f; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + + constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; + +#pragma unroll + for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { + const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const size_t chunk_stage_offset_X = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, + chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, + &mbar[prefetch_buff], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, + chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], + is_master_thread); + } + } + +#pragma unroll + for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { + const size_t buff = iter % FP8_BUFFERS_NUM; + 
const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; + if (next_iter < FP8_ITERATIONS) { + const size_t next_buff = next_iter % FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, + chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, + chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + } + } + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { + const size_t stage_offset_Y = stage; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t row = row_base + shmem_offset_y; + const bool row_out_of_bounds = row >= rows; + const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; + + float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + if constexpr (IS_DACT) { + if (!out_of_bounds) { + partial_dbias += elt; + } + } else { + // If no activation, elt is 0 so we can safely do this + partial_dbias += elt; + } + } + __builtin_assume(amax >= 0); + if (IS_DACT) { + if (!out_of_bounds) { + amax = fmaxf(amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + amax = fmaxf(amax, fabsf(elt)); + } + out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); + } + 
+ // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), chunk_it_offset_x, + chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. + ptx::cp_async_bulk_wait_group_read(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + parity ^= 1; + + if constexpr (IS_DBIAS) { + const size_t dbias_offset_X = my_column; + const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + if (!col_out_of_bounds) { + dbias_workspace[dbias_offset] = partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_2D_kernel + +namespace quantize_1D_kernel { +using namespace quantize_2D_kernel; + +constexpr size_t CHUNKS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; +constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; +constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; +constexpr size_t 
CHUNKS_PER_ITERATION = 32; +constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; +constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; +constexpr size_t SHMEM_BUFFERS = 2; +static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, + float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const IType *input = input_ptr + block_offset; + OType *output = output_ptr + block_offset; + + float amax = 0; + const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; + + // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + + constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + + const bool is_master_thread = (threadIdx.x == 0); + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[ITERATIONS]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); + +#pragma unroll + for (int iter = 0; iter < ITERATIONS; ++iter) { + const size_t buff = iter % SHMEM_BUFFERS; + const size_t it_offset = iter * SHMEM_DIM; + + const size_t next_iter = iter + 1; + const size_t next_buff = next_iter % SHMEM_BUFFERS; + const size_t next_iter_offset = next_iter * SHMEM_DIM; + + if (next_iter < ITERATIONS) { + copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, + &(mbar[next_iter]), is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[iter], parity); + +#pragma unroll + for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { + const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + float elt = static_cast(in_sh[buff][shmem_offset]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(elt)); + out_sh[buff][shmem_offset] = static_cast(elt * scale); + } + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + ptx::cp_async_bulk_tensor_1d_shared_to_global( + reinterpret_cast(output + it_offset), + reinterpret_cast(&out_sh[buff]), transaction_size_OUT); + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + + // Wait for TMA transfer to have finished reading shared memory. 
+ ptx::cp_async_bulk_wait_group_read<1>(); + } + } + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + amax = reduce_max(amax, warp_id); + // Update the global amax + if (is_master_thread) { + atomicMaxFloat(amax_ptr, amax); + } + } + + // Update scale-inverse + if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { + reciprocal(scale_inv_ptr, scale); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_1D_kernel + +template +void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace quantize_1D_kernel; + const size_t N = product(input.data.shape); + + const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); + NVTE_CHECK(isFullTile, "Only full tiles are supported."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + const size_t chunks = DIVUP(N, CHUNK_SIZE); + const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); + + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + const float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(THREADS_PER_BLOCK); + const dim3 grid(blocks); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + const IType *input_ptr = reinterpret_cast(input.data.dptr); + OType *output_ptr = reinterpret_cast(output->data.dptr); + + cast_fp8_1D_kernel<<>>( + input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) + ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor 
*output, Tensor *dbias, + Tensor *workspace, cudaStream_t stream) { + using namespace quantize_2D_kernel; + checkCuDriverContext(stream); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); + const size_t blocks_Y = chunks_Y; + const size_t blocks_X = chunks_X; + + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); + float *const scale_ptr = reinterpret_cast(output->scale.dptr); + + const dim3 block(FP8_THREADS_PER_CHUNK); + const dim3 grid(blocks_X, blocks_Y); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->data.dtype, OType, + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); + } + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, + FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); + + cast_fp8_2D_kernel + <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, + workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, + cols); + NVTE_CHECK_CUDA(cudaGetLastError()); + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} +#endif //#ifdef __HIP_PLATFORM_AMD__ + +namespace detail { +using Empty = transformer_engine::Empty; +__device__ inline float identity(float value, const Empty &) { return value; } +} // namespace detail + +/* HIPCC has strict rules for __device__ functions usage on host. 
+ It forbids not only calling but also other ODR-use assigning to variables + https://github.com/llvm/llvm-project/issues/105825 + Use templated struct wrapper to work around + */ +template +struct ActivationType +{ + static constexpr auto op = OP; +}; + + +template +void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; +#else //#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? detail::identity : OP; +#endif //#ifdef __HIP_PLATFORM_AMD__ + const size_t N = product(input.data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryKernelLauncher( + reinterpret_cast(input.data.dptr), + reinterpret_cast(noop->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, + cudaStream_t stream) { +#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; +#else //#ifdef __HIP_PLATFORM_AMD__ + constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; +#endif //#ifdef __HIP_PLATFORM_AMD__ + const size_t N = product(input->data.shape); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input->data.dtype, IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + output->data.dtype, OType, + if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { + constexpr int nvec = 32 / sizeof(IType); + VectorizedUnaryGradKernelLauncher( + reinterpret_cast(grad.data.dptr), + reinterpret_cast(input->data.dptr), + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->scale.dptr), + reinterpret_cast(output->amax.dptr), + reinterpret_cast(output->scale_inv.dptr), N, {}, stream); + } else { + NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); + }); // NOLINT(*) + ); // NOLINT(*) +} + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, + Tensor *dbias, Tensor *workspace, cudaStream_t stream) { +#ifndef __HIP_PLATFORM_AMD__ + using namespace quantize_1D_kernel; +#endif //#ifndef __HIP_PLATFORM_AMD__ + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "cast_input"); + CheckOutputTensor(*output, "cast_output"); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias != nullptr); + CheckOutputTensor(*dbias, "dbias"); + } + if constexpr (IS_DACT) { + NVTE_CHECK(act_input != nullptr); + CheckInputTensor(*act_input, "activation_input"); + NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); + NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); + } + + NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + +#ifdef __HIP_PLATFORM_AMD__ + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + if constexpr (IS_DBIAS) { + NVTE_CHECK(workspace, "Workspace must be provided when 
IS_DBIAS is true."); + if (workspace->data.dptr == nullptr) { + if constexpr (IS_DACT) { + const size_t partial_rows = DIVUP(rows, TILE_DIM); + size_t total_elements = (rows * cols) + (partial_rows * cols); + workspace->data.shape = {total_elements}; + workspace->data.dtype = DType::kFloat32; + } else { + workspace->data.shape = {rows, cols}; + workspace->data.dtype = DType::kFloat32; + } + return; + } + + const void *ptr_to_reduce = nullptr; + DType dtype_to_reduce; + + workspace->amax = {}; + workspace->scale = {}; + workspace->scale_inv = {}; + + Tensor workspace_buffer; + Tensor partial_sum_buffer; + + if constexpr (IS_DACT) { + // The values to reduce are the result of the dAct function. + NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); + + const size_t partial_rows = DIVUP(rows, TILE_DIM); + const size_t full_size_bytes = rows * cols * sizeof(float); + workspace_buffer = *workspace; + workspace_buffer.data.shape = {rows, cols}; + partial_sum_buffer.data.dptr = reinterpret_cast(workspace->data.dptr) + full_size_bytes; + partial_sum_buffer.data.shape = {partial_rows, cols}; + partial_sum_buffer.data.dtype = DType::kFloat32; + workspace = &partial_sum_buffer; + + CastVectorizedUnaryGradKernelLauncher(input, act_input, &workspace_buffer, stream); + if (output && output->data.dptr) { + CastVectorizedUnaryKernelLauncher(workspace_buffer, noop, output, stream); + } + ptr_to_reduce = workspace_buffer.data.dptr; + dtype_to_reduce = workspace_buffer.data.dtype; + } else { + if (output && output->data.dptr) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + // The values to reduce are just the input values. 
+ ptr_to_reduce = input.data.dptr; + dtype_to_reduce = input.data.dtype; + } + + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias tensor."); + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dbias->data.dtype, DBiasTypeOut, + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + dtype_to_reduce, DTypeReduce, + reduce_dbias_rocm( + reinterpret_cast(ptr_to_reduce), + dbias, rows, cols, stream, workspace); + ); + ); + } else { + if (output && output->data.dptr) { + if constexpr (IS_DACT) { + NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output."); + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } else { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } + } +#else + // Supported by the Arch >= 10.0 + if (is_supported_by_CC_100()) { + if (!IS_DBIAS && !IS_DACT) { + if (common::full_tile_1D_tensor(output, ELEMS_PER_BLOCK) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 + quantize_1D(input, output, stream); + } else { + // Unaligned + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } + } else if (!IS_DBIAS && IS_DACT) { + if (common::dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { + // Aligned AND FP8 (+dAct) + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } else { + // Unaligned + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } else { + quantize_2D(input, act_input, output, dbias, workspace, + stream); + } + } else { + if (IS_DBIAS) { + // zhongboz: should we just ignore IS_ACT here? 
+ NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + + " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); + } + if (!IS_DACT) { + CastVectorizedUnaryKernelLauncher(input, noop, output, stream); + } else { + CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); + } + } +#endif //#ifdef __HIP_PLATFORM_AMD__ +} + +} // namespace fp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_FP8_CUH_ diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh similarity index 82% rename from transformer_engine/common/util/dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index aaeb169b1..38eead606 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -1,45 +1,39 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ -/*! \file dequantize_kernels.cuh - * \brief CUDA kernels to cast from MXFP8. +/*! \file dequantize_mxfp8.cuh + * \brief CUDA kernels to dequantize from MXFP8. 
*/ -#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ #include #ifndef __HIP_PLATFORM_AMD__ #include #endif //#ifndef __HIP_PLATFORM_AMD__ #include -#include - -#include -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/transpose.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_dequantize_kernels.cuh" -#endif +#include -namespace transformer_engine { +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" -namespace dequantization { +#include "./rocm_vectorized_2d.cuh" -#ifndef __HIP_PLATFORM_AMD__ +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace dequantize_kernel { +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_dequantize_mxfp8.cuh" +#else template __global__ void __launch_bounds__(THREADS_PER_CHUNK) dequantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, @@ -218,29 +212,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } #endif // #ifndef __HIP_PLATFORM_AMD__ +} // namespace dequantize_kernel -static void fp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { - NVTE_CHECK(is_fp8_dtype(input.data.dtype), "Input must have FP8 type."); - NVTE_CHECK(!is_fp8_dtype(output->data.dtype), "Output must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - output->data.dtype, OType, - 
- constexpr int nvec = 32 / sizeof(OType); - detail::DequantizeParam p; - p.scale_inv = reinterpret_cast(input.scale_inv.dptr); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), nullptr, - reinterpret_cast(output->data.dptr), nullptr, nullptr, nullptr, N, p, - stream);); // NOLINT(*) - ); // NOLINT(*) -} - -static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { + using namespace dequantize_kernel; bool use_rowwise_scaling = input.has_data(); bool use_colwise_scaling = input.has_columnwise_data(); #ifndef __HIP_PLATFORM_AMD__ @@ -336,34 +311,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s #endif NVTE_CHECK_CUDA(cudaGetLastError()); } -} // namespace dequantization - -namespace detail { - -void dequantize_helper(const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if (is_tensor_scaling(input.scaling_mode)) { - dequantization::fp8_dequantize(input, output, stream); - } else if (is_mxfp_scaling(input.scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - if (1) { -#else - if (is_supported_by_CC_100()) { -#endif - dequantization::mxfp8_dequantize(input, output, stream); - } else { - NVTE_ERROR("MXFP8 Dequantization is NOT supported by architectures < 10.0"); - } - } else { - // TODO(kwyss): Move dequantization code from torch to C++ for NVTE_BLOCK_SCALING - NVTE_ERROR("Not implemented scaling mode: " + to_string(input.scaling_mode) + "."); - } -} - -} // namespace detail +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_DEQUANTIZE_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh similarity index 52% rename from 
transformer_engine/common/util/cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index dcb3aa42d..69e30680c 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -6,272 +6,34 @@ * See LICENSE for license information. ************************************************************************/ -/*! \file cast_gated_kernels.cuh - * \brief CUDA gated activations kernels to cast to/from FP8/MXFP8. +/*! \file gated_mxfp8.cuh + * \brief CUDA kernels to cast to MXFP8 with gated activations. */ -#ifndef TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#ifndef TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ #include #ifndef __HIP_PLATFORM_AMD__ #include #endif //#ifndef __HIP_PLATFORM_AMD__ #include -#include -#include +#include -#include +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_cast_gated_kernels.cuh" -#endif +#include "./rocm_vectorized_2d.cuh" namespace transformer_engine { - -namespace gated_kernels { - -#ifndef __HIP_PLATFORM_AMD__ -constexpr size_t CHUNK_DIM_Y = 128; -constexpr size_t CHUNK_DIM_X = 128; -constexpr size_t THREADS_PER_CHUNK = 512; -constexpr size_t THREADS_PER_CHUNK_X = CHUNK_DIM_X; -constexpr size_t THREADS_PER_CHUNK_Y = THREADS_PER_CHUNK / THREADS_PER_CHUNK_X; // 4 = 512 / 128 -constexpr size_t BUFFERS_NUM = 2; -constexpr size_t BUFFER_DIM_Y = 32; -constexpr size_t BUFFER_DIM_X = CHUNK_DIM_X; // 128 -constexpr size_t SHMEM_DIM_Y = BUFFER_DIM_Y; // 32 -constexpr size_t SHMEM_DIM_X = BUFFER_DIM_X; // 128 - -constexpr size_t BUFFER_STAGES_NUM = BUFFER_DIM_Y / THREADS_PER_CHUNK_Y; // 8 = 32 / 4 -constexpr size_t ITERATIONS = 
CHUNK_DIM_Y / BUFFER_DIM_Y; // 4 = 128 / 32 -static_assert(ITERATIONS >= 1); - -__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_fp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act, - const __grid_constant__ CUtensorMap tensor_map_output_gate, - float *const amax_ptr, float *const scale_inv_ptr, - const float *const scale_ptr, const size_t rows, const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! - uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; - - constexpr size_t in_act_mem = buff_size_aligned_in; - constexpr size_t in_gate_mem = buff_size_aligned_in; - constexpr size_t in_mem = in_act_mem + in_gate_mem; - - constexpr size_t out_act_mem = buff_size_aligned_out; - constexpr size_t in_transaction_size = buff_elems * sizeof(IType); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_grad_sh = reinterpret_cast(dshmem); - IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); - IType *in_gate_sh = reinterpret_cast(dshmem + grad_mem + in_act_mem); - OType *out_act_sh = reinterpret_cast(dshmem + grad_mem + in_mem); - OType *out_gate_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_act_mem); - - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act = reinterpret_cast(&tensor_map_output_act); - const uint64_t *TMAP_output_gate = reinterpret_cast(&tensor_map_output_gate); - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - // Prefetch data of the first stage - - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3(in_grad_sh, TMAP_grad_in, chunk_offset_X, chunk_offset_Y, in_act_sh, - TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, TMAP_in_gate, - chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } else { - copy_2d_to_sharedx2(in_act_sh, TMAP_in_act, chunk_offset_X, chunk_offset_Y, in_gate_sh, - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, in_transaction_size, &mbar[0], - is_master_thread); - } - -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const size_t buff = it % BUFFERS_NUM; - const size_t next_it = it + 1; - if (next_it < ITERATIONS) { - const size_t next_buff = next_it % BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DGATED) { - copy_2d_to_sharedx3( - &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, - &in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, chunk_it_offset_y, - &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, chunk_it_offset_x, chunk_it_offset_y, - in_transaction_size, &mbar[next_it], is_master_thread); - } else { - copy_2d_to_sharedx2(&in_act_sh[next_buff * buff_elems], TMAP_in_act, chunk_it_offset_x, - chunk_it_offset_y, &in_gate_sh[next_buff * buff_elems], TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, in_transaction_size, - &mbar[next_it], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); - - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * 
buff_elems; - OType *out_act_sh_curr = out_act_sh + buff * buff_elems; - OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems; - -#pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); - - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - - const float x = act_elt; - float act_x; - float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; - } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); - } - - float after_dact = dact_x * grad_elt * gate_elt; - float after_dgate = act_x * grad_elt; - - out_act_sh_curr[shmem_idx] = static_cast(scale * after_dact); - out_gate_sh_curr[shmem_idx] = static_cast(scale * after_dgate); - - amax = fmaxf(amax, fabsf(after_dact)); - amax = fmaxf(amax, fabsf(after_dgate)); - } else { - const float after_act = ActOP(act_elt, {}) * gate_elt; - out_act_sh_curr[shmem_idx] = static_cast(scale * after_act); - amax = fmaxf(amax, fabsf(after_act)); - } - } - - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. 
- - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - - // dGeLU - ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, - chunk_it_offset_y, - reinterpret_cast(out_act_sh_curr)); - - if constexpr (IS_DGATED) { - // dGate - ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_sh_curr)); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. 
- if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -namespace mxfp8_kernel { - +namespace dispatch { +namespace mxfp8 { +namespace gated_kernel { +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_gated_mxfp8.cuh" +#else constexpr size_t CHUNK_DIM_Y = 64; constexpr size_t CHUNK_DIM_X = 64; constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; @@ -295,20 +57,21 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 // Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 -template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, - const __grid_constant__ CUtensorMap tensor_map_input_act, - const __grid_constant__ CUtensorMap tensor_map_input_gate, - const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, - const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_gated_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, + const __grid_constant__ CUtensorMap tensor_map_input_act, + const __grid_constant__ CUtensorMap tensor_map_input_gate, + const __grid_constant__ CUtensorMap tensor_map_output_act_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_act_colwise, + const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, 
+ const size_t rows, const size_t cols, + const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) using IType2 = typename ptx::FPx2; using OType2 = typename ptx::FPx2; @@ -378,14 +141,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -420,7 +183,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) int parity = 0; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, @@ -447,7 +210,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; const size_t global_offset_X = block_offset_X; const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], @@ -484,43 +247,55 @@ __global__ void 
__launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); float after_act_elt; float after_gate_elt; - - if constexpr (IS_DGATED) { + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = dgate_elt ? 
act_x * grad_elt : 0.0f; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_elt = static_cast(static_cast(after_gate_elt)); } } after_act_colwise[i] = after_act_elt; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_colwise[i] = after_gate_elt; } // Cache computed activations to avoid computing them again in the 2nd pass along another dimension if constexpr (IS_CACHED_ACT_OP) { cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); } } @@ -530,7 +305,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if (!out_of_bounds) { thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); } } @@ -559,7 +334,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // All threads read the reduced amax (ACT) thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { // Make sure the previous read of the ACT values has been completed, // so the data are not rewritten __syncthreads(); @@ -603,9 +378,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); float block_scale_inverse_gate; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; const size_t scale_idx_gate = scale_idx + 
gate_scale_idx_offset_colwise; if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { @@ -619,7 +395,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { const size_t shmem_offset_elt = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { OType2 out_pair; ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, @@ -665,7 +441,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Load cached elements in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); } // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) @@ -675,7 +451,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); } } @@ -685,7 +461,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], in_cached_act[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], in_cached_gate[w].data.elt[e + 1]}; ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); @@ -697,7 +473,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (!std::is_same_v) { thread_amax_act = static_cast( __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); - if constexpr (IS_DGATED) { + if 
constexpr (IS_BWD) { thread_amax_gate = static_cast( __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); } @@ -715,7 +491,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) in_act.load_from(&in_act_sh[shmem_offset_rowwise]); in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); } @@ -727,34 +503,46 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float gate_elt = static_cast(in_gate.data.elt[e]); float after_act_elt; float after_gate_elt; - - if constexpr (IS_DGATED) { + bool dgate_elt = true; + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + if constexpr (IS_BWD) { float grad_elt = static_cast(in_grad.data.elt[e]); const float x = act_elt; float act_x; float dact_x; - - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; + after_gate_elt = dgate_elt ? 
act_x * grad_elt : 0.0f; after_act_rowwise[j] = after_act_elt; after_gate_rowwise[j] = after_gate_elt; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, p) * gate_elt; after_act_rowwise[j] = after_act_elt; } // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 if constexpr (!std::is_same_v) { after_act_elt = static_cast(static_cast(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { after_gate_elt = static_cast(static_cast(after_gate_elt)); } } @@ -764,7 +552,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); } } @@ -790,7 +578,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float block_scale_inverse_gate; ptx::floatx2 block_scale_inverse_2x_gate; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; @@ -821,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { IType2 in_gate; OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); @@ -836,11 +624,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; 
out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); } } @@ -861,7 +650,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); @@ -871,7 +660,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); @@ -887,95 +676,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } -} // namespace mxfp8_kernel +#endif //#ifndef __HIP_PLATFORM_AMD__ +} // namespace gated_kernel -template -void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, +void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *output, ParamOP &p, cudaStream_t stream) { - checkCuDriverContext(stream); - - if (output->has_data()) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); - } - - 
NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act{}; - alignas(64) CUtensorMap tensor_map_output_gate{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, - cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, 0, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols * 2, cols, typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_output_act, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate, output->data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t 
buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - const size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - cast_fp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, - tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) - ); // NOLINT(*) -} -#endif //#ifdef __HIP_PLATFORM_AMD__ - -template -void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { + using namespace gated_kernel; checkCuDriverContext(stream); const bool USE_ROWWISE_SCALING = output->has_data(); @@ -1001,7 +709,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; + const size_t output_cols = (IS_BWD ? 
2 : 1) * cols; #ifdef __HIP_PLATFORM_AMD__ constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; @@ -1009,20 +717,12 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; constexpr size_t BUFFS_NUM = BUFFERS_NUM; +#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); -#else - - constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; - constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; - constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; - const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); - - constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; - constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; +#ifndef __HIP_PLATFORM_AMD__ const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; @@ -1045,7 +745,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ - const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *tensor_map_grad = IS_BWD ? reinterpret_cast(grad.data.dptr) : nullptr; const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? 
reinterpret_cast(output->data.dptr) : nullptr; @@ -1067,7 +767,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out constexpr size_t input_type_bit_size = TypeInfo::size; constexpr size_t output_type_bit_size = TypeInfo::size; - if constexpr (IS_DGATED) { + if constexpr (IS_BWD) { create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, input_type_bit_size); } @@ -1105,7 +805,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t buff_size_aligned_out = DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); + const size_t grad_mem = (IS_BWD ? buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; @@ -1114,9 +814,10 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out #ifdef __HIP_PLATFORM_AMD__ const size_t out_gate_mem = buff_size_aligned_out; #else - const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); + const size_t out_gate_mem = (IS_BWD ? buff_size_aligned_out : 0); #endif size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; @@ -1128,257 +829,72 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out (USE_ROWWISE_SCALING ? 
32 : 1), SCALE_DIM_X, TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, + quantize_gated_mxfp8_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - cast_mxfp8_gated_kernel + quantize_gated_mxfp8_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + scale_stride_colwise, p); NVTE_CHECK_CUDA(cudaGetLastError()); }))); // NOLINT(*) #else switch (scaling_type) { - case ScalingType::ROWWISE: + case ScalingType::ROWWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::COLWISE: + } + case ScalingType::COLWISE: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - 
mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - case ScalingType::BIDIMENSIONAL: + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_gated_mxfp8_kernel; NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); - - mxfp8_kernel::cast_mxfp8_gated_kernel - <<>>( - tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, - tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, - tensor_map_output_act_colwise, tensor_map_output_gate_colwise, - scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + kernel<<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, + scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise, p); break; - } + } + } NVTE_CHECK_CUDA(cudaGetLastError()); // NOLINT(*) #endif ); // NOLINT(*) ); // NOLINT(*) } -template -void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) { - 
CheckInputTensor(input, "gated_act_input"); - CheckOutputTensor(*output, "gated_act_output"); - NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(), - "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.flat_last_dim() % 2 == 0, - "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [", - input.flat_first_dim(), ", ", input.flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == input.flat_last_dim() / 2, - "Wrong output shape. Expected (after flattening) [*, ", input.flat_last_dim() / 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - GatedActivationKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), input.flat_first_dim(), - output->flat_last_dim(), {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) { - CheckInputTensor(grad, "dgated_act_grad"); - CheckInputTensor(input, "dgated_act_input"); - CheckOutputTensor(*output, "dgated_act_output"); - NVTE_CHECK(output->flat_first_dim() == grad.flat_first_dim(), - "Wrong output shape. 
Expected (after flattening) [", grad.flat_first_dim(), - ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(output->flat_last_dim() == grad.flat_last_dim() * 2, - "Wrong output shape. Expected (after flattening) [*, ", grad.flat_last_dim() * 2, - "], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "]."); - NVTE_CHECK(input.data.shape == output->data.shape, - "Input and output shapes must match. Input shape: ", input.data.shape, - ", output shape: ", output->data.shape, "."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->dtype(), OType, - - if (!is_fp8_dtype(output->data.dtype) || - is_delayed_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - DGatedActivationKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input.data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), grad.flat_first_dim(), - grad.flat_last_dim(), {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, - cudaStream_t stream) { - constexpr bool allow_empty = false; - CheckInputTensor(gated_input, "gated_input"); - CheckOutputTensor(*output, "output", allow_empty); - - NVTE_CHECK(gated_input.flat_last_dim() % 2 == 0, "Number of columns must be even."); - - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_DGATED ? 
2 : 1) * cols; - - if constexpr (IS_DGATED) { - CheckInputTensor(grad, "grad"); - NVTE_CHECK(!is_fp8_dtype(grad.data.dtype), "Grad input must be in higher precision."); - NVTE_CHECK(grad.data.dtype == gated_input.data.dtype, "Types of both inputs must match."); - NVTE_CHECK(grad.flat_first_dim() == rows, "Wrong dimension of the grad input."); - NVTE_CHECK(grad.flat_last_dim() == cols, "Wrong dimension of the grad input."); - } - - NVTE_CHECK(output->has_data() || output->has_columnwise_data(), - "Either rowwise or columnwise output data need to be allocated."); - - bool is_fp8_rowwise_output = true; - bool is_fp8_colwise_output = true; - if (output->has_data()) { - is_fp8_rowwise_output = is_fp8_dtype(output->data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - if (output->has_columnwise_data()) { - is_fp8_colwise_output = is_fp8_dtype(output->columnwise_data.dtype); - NVTE_CHECK(output->flat_first_dim() == rows, "Wrong dimension of the output."); - NVTE_CHECK(output->flat_last_dim() == output_cols, "Wrong dimension of the output."); - } - - const bool use_tma_kernels = is_fp8_rowwise_output && is_fp8_colwise_output && cols % 32 == 0; - - if (is_delayed_tensor_scaling(output->scaling_mode)) { -#ifdef __HIP_PLATFORM_AMD__ - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } -#else - if (use_tma_kernels) { - cast_fp8_gated(grad, gated_input, output, stream); - } else { - if constexpr (IS_DGATED) { - cast_dgated(grad, gated_input, output, stream); - } else { - cast_gated(gated_input, output, stream); - } - } -#endif - } else if (is_mxfp_scaling(output->scaling_mode)) { - if (use_tma_kernels) { - cast_mxfp8_gated(grad, gated_input, output, stream); - } else { - NVTE_ERROR("Invalid input shape. 
Expected the last dimension to be divisible ", - "by 32, got input of shape ", gated_input.data.shape); - } - } else { - NVTE_ERROR("Not supported scaling mode"); - } -} -} // namespace gated_kernels - -namespace detail { - -template -void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, - cudaStream_t stream) { - using namespace gated_kernels; - Tensor grad_empty_tensor; - const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; - const Tensor gated_input_tensor = *convertNVTETensorCheck(gated_input); - Tensor *output_tensor = convertNVTETensorCheck(output); - -#ifdef __HIP_PLATFORM_AMD__ - if (1) { -#else - if (is_supported_by_CC_100()) { -#endif - quantize_gated(grad_tensor, gated_input_tensor, - output_tensor, stream); - } else { - if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { - if constexpr (IS_DGATED) { - cast_dgated(grad_tensor, gated_input_tensor, output_tensor, stream); - } else { - cast_gated(gated_input_tensor, output_tensor, stream); - } - } else { - // MX scaling - NVTE_ERROR("Not supported by the Arch < 10.0"); - } - } -} -} // namespace detail - +} // namespace mxfp8 +} // namespace dispatch } // namespace transformer_engine -#endif // TRANSFORMER_ENGINE_CAST_GATED_KERNELS_CUH_ +#endif // TRANSFORMER_ENGINE_GATED_MXFP8_CUH_ diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh new file mode 100644 index 000000000..8e25b3f65 --- /dev/null +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -0,0 +1,762 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +/*! \file quantize_mxfp8.cuh + * \brief CUDA kernels to quantize to MXFP8. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ + +#include +#ifndef __HIP_PLATFORM_AMD__ +#include +#endif //#ifndef __HIP_PLATFORM_AMD__ +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "../core/common.cuh" + +#include "./rocm_vectorized_2d.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace mxfp8 { +namespace quantize_kernel { +#ifdef __HIP_PLATFORM_AMD__ +#include "rocm_quantize_mxfp8.cuh" +#else +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_mxfp8_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_act_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + 
+ using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + 
tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + + OType *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; + if constexpr (IS_DBIAS) { +#pragma unroll + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; + } + } + + float block_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + int parity = 0; + + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } else { + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); + } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + + // 2. 
Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in.data.elt[e]); + elt *= OP(act_in_elt, {}); + } + + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + 
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + if (rowwise_scale_is_within_bounds) { + scales_rowwise[scale_idx] = biased_exponent; + } + + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; + + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); + } + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, 
thread_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. + + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + + parity ^= 1; + + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); + + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + 
e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; + } + } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } + } + + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + block_amax = reduce_max(block_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, block_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +#endif //#ifndef __HIP_PLATFORM_AMD__ +} // namespace quantize_kernel + +template +void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { + using namespace quantize_kernel; +#ifndef __HIP_PLATFORM_AMD__ + checkCuDriverContext(stream); +#endif + + bool use_rowwise_scaling = output->has_data(); + bool use_colwise_scaling = output->has_columnwise_data(); + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); + + if (use_rowwise_scaling) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Columnwise scaling 
tensor must be allocated"); + } + CheckNoopTensor(*noop, "cast_noop"); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; + constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; + constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; +#else + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; + + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +#endif + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_PER_CHUNK; + + const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; + const size_t scale_stride_colwise = + use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; + + e8m0_t *const scales_rowwise_ptr = + use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; + e8m0_t *const scales_colwise_ptr = + use_colwise_scaling ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; + +#ifndef __HIP_PLATFORM_AMD__ + ScalingType scaling_type; + if (use_rowwise_scaling && (!use_colwise_scaling)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!use_rowwise_scaling) && use_colwise_scaling) { + scaling_type = ScalingType::COLWISE; + } else if (use_rowwise_scaling && use_colwise_scaling) { + scaling_type = ScalingType::BIDIMENSIONAL; + } +#endif + + if constexpr (IS_DBIAS) { + NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); + NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + + if (workspace->data.dptr == nullptr) { + workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.dtype = DType::kFloat32; + return; + } + } + + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; + float *const amax_ptr = reinterpret_cast(output->amax.dptr); + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + +#ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + !(cols % (32 * sizeof(IType))), IS_ALIGNED, + quantize_mxfp8_kernel + <<>>( + reinterpret_cast(input.data.dptr), + (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, + reinterpret_cast(output->data.dptr), + reinterpret_cast(output->columnwise_data.dptr), + scales_rowwise_ptr, scales_colwise_ptr, + reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, + rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))); // NOLINT(*) +#else // #ifdef __HIP_PLATFORM_AMD__ + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_act_input{}; + alignas(64) CUtensorMap tensor_map_output_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + + if constexpr (IS_DACT) { + create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, input_type_bit_size); + } + + if (use_rowwise_scaling) { + create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + if (use_colwise_scaling) { + create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); + } + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? 
buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::COLWISE: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + case ScalingType::BIDIMENSIONAL: { + auto kernel = + quantize_mxfp8_kernel; + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); + + kernel<<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + break; + } + } +#endif // #ifdef __HIP_PLATFORM_AMD__ + + if constexpr (IS_DBIAS) { + common::reduce_dbias(workspace_ptr, dbias, dbias_rows, 
dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) +} + +} // namespace mxfp8 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_MXFP8_CUH_ diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh similarity index 86% rename from transformer_engine/common/util/rocm_dequantize_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh index 398e4c0ad..49c57737c 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_dequantize_mxfp8.cuh @@ -5,26 +5,7 @@ ************************************************************************/ #pragma once - -#include -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/cast.h" -#include "transpose/cast_transpose.h" -#include "transformer_engine/transpose.h" -#include "utils.cuh" -#include "vectorized_pointwise.h" - -namespace transformer_engine { - -namespace dequantization { +// drop-in rocm replacement for mxfp8 dequantize kernel constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; @@ -86,7 +67,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y; const int chunk_it_offset_x = chunk_offset_X; - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -127,12 +108,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); - bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, - chunk_it_offset_y, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, rows, cols); + 
transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_sh[0][0], output_ptr, chunk_it_offset_x, + chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); } } -} // namespace dequantization -} // namespace transformer_engine + diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh similarity index 84% rename from transformer_engine/common/util/rocm_cast_gated_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh index a53fd51c5..a8c02e4f8 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_gated_mxfp8.cuh @@ -5,22 +5,7 @@ ************************************************************************/ #pragma once - -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/activation.h" -#include "transformer_engine/cast.h" -#include "vectorized_pointwise.h" -#include "utils.cuh" - -namespace transformer_engine { -namespace gated_kernels { +// drop-in rocm replacement for mxfp8 gated quantize kernel constexpr size_t ALIGNMENT_SIZE = 128; // TODO: Identify optimal chunk/thread size for MI350+ @@ -45,16 +30,17 @@ template __global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_gated_kernel(const IType *grad_ptr, - const IType *input_act, - const IType *input_gate, - OType *output_act_rowwise, - OType *output_gate_rowwise, - OType *output_act_colwise, - OType *output_gate_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_gated_mxfp8_kernel( + const IType *grad_ptr, + const IType *input_act, + const IType *input_gate, + OType *output_act_rowwise, + OType *output_gate_rowwise, + OType *output_act_colwise, + OType 
*output_gate_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise, const ParamOP p) { constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; constexpr bool COMPUTE_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; @@ -136,16 +122,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate bulk tensor copy if constexpr (IS_DGATED) { - copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } // Act - copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); // Gate - copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, + transformer_engine::rocm::copy_2d_to_shared(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y, 2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -171,24 +157,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh[shmem_idx]); float gate_elt = static_cast(in_gate_sh[shmem_idx]); + bool dgate_elt = true; // gating is ideally an identity function + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp + gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f; + } + if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_idx]); const float x = act_elt; float act_x; float dact_x; - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); + if 
constexpr (std::is_same::value) { + const float x = min(act_elt, p.limit); + const float s = sigmoidf(p.alpha * x); act_x = x * s; - dact_x = x * s * (1 - s) + s; + dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, p); + dact_x = DActOP(x, p); + } } + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; + after_dgate_reg[stage] = dgate_elt ? act_x * grad_elt : 0.0f; } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + after_dact_reg[stage] = ActOP(act_elt, p) * gate_elt; } // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 @@ -355,24 +356,22 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); if constexpr (IS_DGATED) { - 
bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols); } } __syncthreads(); } } -} // namespace gated_kernels -} // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh similarity index 65% rename from transformer_engine/common/util/rocm_cast_kernels.cuh rename to transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh index e39e0a4a7..d5b51a2f4 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh @@ -4,28 +4,7 @@ * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once - -#include -#include -#include - -#include "common.h" -#include "math.h" -#include "ptx.cuh" -#include "rocm_vectorized_2d.cuh" -#include "transformer_engine/cast.h" -#include "transpose/cast_transpose.h" -#include "vectorized_pointwise.h" -#include "utils.cuh" - -namespace transformer_engine { - -// Forward declaration, definition is in cast_kernels.cuh -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream); - +// drop-in replacement for rocm quantize_mxfp8 kernels constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; @@ -53,14 +32,15 @@ template __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const IType *input_ptr, - const IType *act_input_ptr, - OType *output_rowwise, - OType *output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const 
dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { + quantize_mxfp8_kernel( + const IType *input_ptr, + const IType *act_input_ptr, + OType *output_rowwise, + OType *output_colwise, + e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, + const float *noop, float *const dbias_workspace, float *const amax_ptr, + const size_t rows, const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { if (noop != nullptr && noop[0] == 1.0f) return; } @@ -163,11 +143,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const int chunk_it_offset_x = chunk_offset_X; const size_t row_base = chunk_it_offset_y; if constexpr (IS_DACT) { - copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, + transformer_engine::rocm::copy_2d_to_shared(&act_in_sh[0][0], act_input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } - copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, + transformer_engine::rocm::copy_2d_to_shared(&in_sh[0][0], input_ptr, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); __syncthreads(); @@ -310,12 +290,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __syncthreads(); if constexpr (USE_ROWWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x, chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } if constexpr (USE_COLWISE_SCALING) { - bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, + transformer_engine::rocm::bulk_tensor_2d_shared_to_global(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x, chunk_it_offset_y, 
cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols); } @@ -393,165 +373,3 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) atomicMaxFloat(amax_ptr, block_amax); } } - -// Forward declaration of functions defined in `cast_kernels.cuh` -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream); - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream); - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream); - -constexpr size_t TILE_DIM = 32; -template -__global__ void partial_reduce_kernel(const DTypeReduce* input, float* partial_output, int rows, int cols) { - __shared__ float tile[TILE_DIM][TILE_DIM]; - - int tile_start_col = blockIdx.x * TILE_DIM; - int tile_start_row = blockIdx.y * TILE_DIM; - int thread_col_in_tile = threadIdx.x; - int thread_row_in_tile = threadIdx.y; - - int global_col = tile_start_col + thread_col_in_tile; - int global_row = tile_start_row + thread_row_in_tile; - - if (global_row < rows && global_col < cols) { - tile[thread_row_in_tile][thread_col_in_tile] = static_cast(input[global_row * cols + global_col]); - } else { - tile[thread_row_in_tile][thread_col_in_tile] = 0.0f; - } - __syncthreads(); - - for (int stride = TILE_DIM / 2; stride > 0; stride /= 2) { - if (thread_row_in_tile < stride) { - tile[thread_row_in_tile][thread_col_in_tile] += tile[thread_row_in_tile + stride][thread_col_in_tile]; - } - __syncthreads(); - } - - if (thread_row_in_tile == 0 && global_col < cols) { - partial_output[blockIdx.y * cols + global_col] = tile[0][thread_col_in_tile]; - } -} - -template -void reduce_dbias_rocm(const DTypeReduce *workspace_ptr, Tensor *dbias, const size_t rows, - const size_t cols, cudaStream_t stream, Tensor* partial_sum_workspace) { - dim3 block_dim_partial(TILE_DIM, TILE_DIM); - dim3 
grid_dim_partial(DIVUP(cols, TILE_DIM), DIVUP(rows, TILE_DIM)); - - const size_t partial_rows = grid_dim_partial.y; - float* partial_workspace = reinterpret_cast(partial_sum_workspace->data.dptr); - - partial_reduce_kernel<<>>( - workspace_ptr, - partial_workspace, - rows, cols); - - reduce_dbias(partial_workspace, dbias, partial_rows, cols, stream); -} - -template -void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias, "DBias tensor must be provided when IS_DBIAS is true."); - NVTE_CHECK(workspace, "Workspace must be provided when IS_DBIAS is true."); - if (workspace->data.dptr == nullptr) { - if constexpr (IS_DACT) { - const size_t partial_rows = DIVUP(rows, TILE_DIM); - size_t total_elements = (rows * cols) + (partial_rows * cols); - workspace->data.shape = {total_elements}; - workspace->data.dtype = DType::kFloat32; - } else { - workspace->data.shape = {rows, cols}; - workspace->data.dtype = DType::kFloat32; - } - return; - } - - const void *ptr_to_reduce = nullptr; - DType dtype_to_reduce; - - workspace->amax = {}; - workspace->scale = {}; - workspace->scale_inv = {}; - - Tensor workspace_buffer; - Tensor partial_sum_buffer; - - if constexpr (IS_DACT) { - // The values to reduce are the result of the dAct function. 
- NVTE_CHECK(act_input, "Gradient tensor must be provided for DBias + DACT."); - - const size_t partial_rows = DIVUP(rows, TILE_DIM); - const size_t full_size_bytes = rows * cols * sizeof(float); - workspace_buffer = *workspace; - workspace_buffer.data.shape = {rows, cols}; - partial_sum_buffer.data.dptr = reinterpret_cast(workspace->data.dptr) + full_size_bytes; - partial_sum_buffer.data.shape = {partial_rows, cols}; - partial_sum_buffer.data.dtype = DType::kFloat32; - workspace = &partial_sum_buffer; - - CastVectorizedUnaryGradKernelLauncher(input, act_input, &workspace_buffer, stream); - if (output && output->data.dptr) { - CastVectorizedUnaryKernelLauncher(workspace_buffer, noop, output, stream); - } - ptr_to_reduce = workspace_buffer.data.dptr; - dtype_to_reduce = workspace_buffer.data.dtype; - } else { - if (output && output->data.dptr) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - // The values to reduce are just the input values. - ptr_to_reduce = input.data.dptr; - dtype_to_reduce = input.data.dtype; - } - - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias tensor."); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dbias->data.dtype, DBiasTypeOut, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - dtype_to_reduce, DTypeReduce, - reduce_dbias_rocm( - reinterpret_cast(ptr_to_reduce), - dbias, rows, cols, stream, workspace); - ); - ); - } else { - if (output && output->data.dptr) { - if constexpr (IS_DACT) { - NVTE_CHECK(act_input, "Gradient tensor must be provided for DACT output."); - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } else { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - - -} // namespace transformer_engine 
diff --git a/transformer_engine/common/util/rocm_vectorized_2d.cuh b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh similarity index 94% rename from transformer_engine/common/util/rocm_vectorized_2d.cuh rename to transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh index 5877ddd87..50474f308 100644 --- a/transformer_engine/common/util/rocm_vectorized_2d.cuh +++ b/transformer_engine/common/cast/mxfp8/rocm_vectorized_2d.cuh @@ -1,14 +1,14 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once -#include "../util/vectorized_pointwise.h" +#include "../../util/vectorized_pointwise.h" -namespace transformer_engine { +namespace transformer_engine::rocm { // These 2d copy functions replace TMA tensormap async copies for AMD GPUs. template __device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col, @@ -78,4 +78,4 @@ __device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T * } } } -} // namespace transformer_engine +} // namespace transformer_engine::rocm \ No newline at end of file diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh new file mode 100644 index 000000000..cff846490 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -0,0 +1,112 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! 
\file core_nvfp4.cuh + * \brief Core functions used in NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ + +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../util/curanddx.hpp" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif // FP4_TYPE_SUPPORTED + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +using nvfp4_scale_t = fp8e4m3; + +namespace quantization_and_transposition_SF { +#if FP4_TYPE_SUPPORTED +// Used in transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ nvfp4_scale_t compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + // constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const nvfp4_scale_t S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + // NOTE: Divide by 6.0f is not elegant and not efficient. + // However, this is part of the emulation code to ensure exact match. 
+ using namespace detail; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + const float S_dec_b = block_amax / fp4_max * S_enc; + return static_cast(fminf(S_dec_b, TypeExtrema::max)); +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantization_and_transposition_SF + +namespace quantization_SF { +#if FP4_TYPE_SUPPORTED +// Used in non-transpose variant +// Compute per-block E4M3 encoding/decoding scaling factor +__device__ __forceinline__ fp8e4m3 compute_decoding_scaling_factor(const float block_amax, + const float S_enc) { + constexpr float rcp_6f = 1.0f / 6.0f; + // const float S_dec_b = block_amax * rcp_6f; + // const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // return S_dec_b_fp8; + return static_cast(block_amax * rcp_6f * S_enc); +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantization_SF + +namespace core { + +#if FP4_TYPE_SUPPORTED +using namespace ptx; + +// Compute the global encode scale factor for a given global amax +__device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { + using namespace detail; + constexpr float fp8_max = TypeExtrema::max; // 448.0f; + constexpr float fp4_max = TypeExtrema::max; // 6.0f; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.0f || global_encode_scale == 0.0f) { + return 1.0f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10> &rng, + // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4 &random_uint4, int &rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + random_uint4 = rng.generate4(); + } + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t *const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t 
rbits = rbits_arr[rnd_idx++]; + return rbits; +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace core +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_CORE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh new file mode 100644 index 000000000..5307cad37 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -0,0 +1,115 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dequantize_nvfp4.cuh + * \brief CUDA kernels to dequantize from NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" + +#if FP4_TYPE_SUPPORTED +#include +#endif // FP4_TYPE_SUPPORTED + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace dequantize_kernel { +#if FP4_TYPE_SUPPORTED +template +__global__ void __launch_bounds__(512) + dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride) { + const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; + const size_t x = thread_idx % M; + const size_t y = thread_idx / M; + + if (y >= N) { + return; + } + + union fp4vec { + uint64_t vec; + fp4e2m1x4 small_vec[4]; + }; + using OVec = Vec; + const uint64_t *const input_vectorized = reinterpret_cast(input); + OVec *output_vec = reinterpret_cast(output); + + const size_t my_index = x + y * M; + const size_t 
my_scale_index = x + y * scale_stride; + const size_t my_output_index = (x + y * M) * 4; + fp4vec value; + value.vec = input_vectorized[my_index]; + fp8e4m3 scale = scales[my_scale_index]; + float amax = *tensor_amax; + constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float final_scale = static_cast(scale) * amax * factor_inv; +#pragma unroll + for (int i = 0; i < 4; i++) { + float4 current = static_cast(value.small_vec[i]); + OVec out; + out.data.elt[0] = static_cast(current.x * final_scale); + out.data.elt[1] = static_cast(current.y * final_scale); + out.data.elt[2] = static_cast(current.z * final_scale); + out.data.elt[3] = static_cast(current.w * final_scale); + output_vec[my_output_index + i] = out; + } +} +#endif // FP4_TYPE_SUPPORTED +} // namespace dequantize_kernel + +inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace dequantize_kernel; + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output"); + NVTE_CHECK(input.data.dtype == DType::kFloat4E2M1, "Input must have FP4 type."); + NVTE_CHECK(is_high_precision_dtype(output->data.dtype), "Output must be in higher precision."); + NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); + + constexpr int FP4_BLOCK_SIZE = 16; + const size_t N = input.flat_first_dim(); + const size_t M = input.flat_last_dim(); + + NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", + FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); + + const size_t Mread = M / FP4_BLOCK_SIZE; + const size_t total = N * Mread; + const size_t threads = 512; + const size_t blocks = DIVUP(total, threads); + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + output->data.dtype, OType, + + dequantize_fp4_kernel<<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + 
input.scale_inv.shape.back());); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); +#endif // FP4_TYPE_SUPPORTED +} +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_DEQUANTIZE_NVFP4_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh new file mode 100644 index 000000000..83ad8fd40 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -0,0 +1,688 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4.cuh + * \brief CUDA kernels to cast to NVFP4. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_ + +#include +#include +#include +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../util/ptx.cuh" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace quantize_kernel { + +using namespace ptx; +using namespace quantization_SF; +using namespace core; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 16; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; + +constexpr size_t PACK_SIZE = 8; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 8 = 128 / 16 + +#define 
DIRECT_SCALING_FACTORS_STORE 1 + +template +__global__ void __launch_bounds__(THREADS_PER_CHUNK) + quantize_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output_rowwise, + const __grid_constant__ CUtensorMap tensor_map_output_colwise, + fp8e4m3 *const scales_rowwise_e4m3, e8m0_t *const scales_colwise_e8m0, + const float *noop, float *const amax_ptr, + const float *const nvfp4_second_stage_scale_ptr, const size_t rows, + const size_t cols, const size_t scale_stride_rowwise, + const size_t scale_stride_colwise) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + constexpr bool ROWWISE_SCALING = true; + constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT = + (!COMPUTE_ACTIVATIONS) && (!std::is_same_v); + + using IType2 = typename ptx::FPx2; + + if constexpr (!COMPUTE_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + } + constexpr size_t NVFP4_SCALING_FACTORS_PER_CHUNK_ROW = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_X_ROWWISE = NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + constexpr size_t THREADS_Y_ROWWISE = THREADS_PER_CHUNK / THREADS_X_ROWWISE; + + static_assert(BUFF_DIM_Y >= SCALE_DIM_Y && + "Number of buffer rows must be greater or equal to the size of the columwise " + "scaling block\0"); + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y); + static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE && + "Number of buffer rows must be greater or equal to the number of rowwise " + "processing threads in Y dimension\0"); + + constexpr size_t BUFF_IN_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_OUT_DIM_X = (CHUNK_DIM_X * 4) / 8; // Holds 2 elements of 4-bit size + constexpr size_t BUFF_IN_DIM = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t BUFF_OUT_DIM = BUFF_DIM_Y * BUFF_OUT_DIM_X; + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + + constexpr size_t ITERATIONS_ROWWISE = BUFF_DIM_Y / THREADS_Y_ROWWISE; + // static_assert(THREADS_PER_CHUNK >= CHUNK_DIM_X); // there should be a 
sufficient number of + // // threads to process one row in a single iteration + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; // Each thread processes two adjacent elements + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const bool rowwise_scale_is_within_bounds = scales_offset_X_rowwise < cols; + const bool colwise_scale_is_within_bounds = scales_offset_X_colwise < cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane 
/ THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_nvfp4 = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out_mxfp8 = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t buff_size_nvfp4_scales = + CHUNK_DIM_Y * (CHUNK_DIM_X / SCALE_DIM_X) * sizeof(fp8e4m3); + constexpr size_t buff_size_mxfp8_scales = + (CHUNK_DIM_Y / SCALE_DIM_Y) * CHUNK_DIM_X * sizeof(fp8e8m0); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = (ROWWISE_SCALING ? buff_size_aligned_out_nvfp4 : 0); + constexpr size_t out_mem_colwise_data = (COLWISE_SCALING ? buff_size_aligned_out_mxfp8 : 0); + constexpr size_t out_mem_rowwise_scales = (ROWWISE_SCALING ? buff_size_nvfp4_scales : 0); + constexpr size_t out_mem_colwise_scales = (COLWISE_SCALING ? buff_size_mxfp8_scales : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_rowwise_data_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + fp8e4m3 *out_rowwise_scales_sh = + reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + e8m0_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factor for all S_dec_b + const float S_enc = + (nvfp4_second_stage_scale_ptr == nullptr) ? 1.0f : 1.0f / (*nvfp4_second_stage_scale_ptr); + + float thread_amax = 0.0f; + +// Initialize shared memory barrier with the number of threads participating in the barrier. +#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + const int buff_offset_in = buff * BUFF_IN_DIM; + const int buff_offset_out = buff * BUFF_OUT_DIM; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_IN_DIM; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff_offset_in + tid_X_colwise; + + block_amax = 0.0f; + float in_compute_colwise[SCALE_DIM_Y]; + IType in_colwise_IType[SCALE_DIM_Y]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType block_amax_f16 = static_cast(0.0f); +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i])); + } + block_amax = static_cast(block_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + if 
constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(block_amax * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + if (colwise_scale_is_within_bounds) { + scales_colwise_e8m0[scale_idx] = biased_exponent; + } + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + +// 3. Scale elements +#pragma unroll + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; + + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_IN_DIM_X; + out_colwise_data_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } + + if constexpr (ROWWISE_SCALING) { + const int stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (int it = 0; it < ITERATIONS_ROWWISE; ++it) { + const int it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const int shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const int shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + const int it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE; + + block_amax = 0.0f; + float 
in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + block_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = + (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + block_amax = fmaxf(block_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + block_amax = fmaxf(block_amax, fabsf(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. 
Compute E4M3 scaling factor + const fp8e4m3 S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc); + +#if DIRECT_SCALING_FACTORS_STORE + // Check boundaries + if (rowwise_scale_is_within_bounds) { + const int scales_offset_Y = + scales_offset_Y_rowwise + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE; + const int scales_offset_X = scales_offset_X_rowwise; + const int scale_idx_global = scales_offset_Y * scale_stride_rowwise + scales_offset_X; + scales_rowwise_e4m3[scale_idx_global] = S_dec_b_fp8; + } +#else + const int shmem_scales_offset_Y = + stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise; + const int shmem_scales_offset_X = tid_X_rowwise; + const int scale_idx = + shmem_scales_offset_Y * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW + shmem_scales_offset_X; + out_rowwise_scales_sh[scale_idx] = S_dec_b_fp8; +#endif + // Compute "correct" per-block encoding scaling factor + const float block_scale_inverse = + __fdiv_rn(S_enc, static_cast(S_dec_b_fp8)); // S_enc_b_fp8 + +// 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; // Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + in01 = in_IType[w].data.elt[2 * e]; + in23 = in_IType[w].data.elt[2 * e + 1]; + } else if constexpr (IS_CACHED_ACT_OP) { + in01.x = in_cached[w].data.elt[4 * e]; + in01.y = in_cached[w].data.elt[4 * e + 1]; + in23.x = in_cached[w].data.elt[4 * e + 2]; + in23.y = in_cached[w].data.elt[4 * e + 3]; + } else { + const int j = w * PACK_SIZE + 4 * e; + in01.x = in_compute_rowwise[j]; + in01.y = in_compute_rowwise[j + 1]; + in23.x = in_compute_rowwise[j + 2]; + in23.y = in_compute_rowwise[j + 3]; + } + fp4e2m1x4 &out_quad = reinterpret_cast(out.data.elt[e]); + ptx::mul_cvt_4x(out_quad, in01, in23, block_scale_inverse); + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + __builtin_assume(block_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset_nvfp4 = buff * BUFF_OUT_DIM; + const int buff_offset_mxfp8 = buff * BUFF_IN_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_data_sh[buff_offset_nvfp4])); + } + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_data_sh[buff_offset_mxfp8])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } + +#if !DIRECT_SCALING_FACTORS_STORE + // Vectorized store of scaling factors. + // Each thread stores multiple scaling factors in one store instruction. 
+ if constexpr (ROWWISE_SCALING) { + // Number of scaling factors = CHUNK_DIM_X / SCALE_DIM_X + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + threadIdx.x; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise; + const int scale_idx_global = + scales_offset_Y_rowwise * scale_stride_rowwise + scales_offset_X_rowwise; + const int scale_idx_shmem = threadIdx.x * NVFP4_SCALING_FACTORS_PER_CHUNK_ROW; + + if ((threadIdx.x < CHUNK_DIM_Y) && (scales_offset_Y_rowwise < rows) && + (scales_offset_X_rowwise < (cols / SCALE_DIM_X))) { + using ScalesVec_t = Vec; + const ScalesVec_t &scales = + *reinterpret_cast(&out_rowwise_scales_sh[scale_idx_shmem]); + scales.store_to(&scales_rowwise_e4m3[scale_idx_global]); + } + } +#endif + + float chunk_amax = 0.0f; + if (amax_ptr != nullptr) { + const int warp_id = threadIdx.x / THREADS_PER_WARP; + // Reduce the amax over the block + chunk_amax = reduce_max(thread_amax, warp_id); + } + + if (is_master_thread && amax_ptr != nullptr) { + atomicMaxFloat(amax_ptr, chunk_amax); + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +} // namespace quantize_kernel + +// This kernel supports only two scaling cases: +// 1. r16c0 - Rowwise NVFP4 +// 2. 
// r16c32 - Rowwise NVFP4 AND Colwise MXFP8
/* Host-side dispatch: quantizes `input` into `output` in one pass.
 * Row-wise: NVFP4 data with FP8E4M3 scale factors, one per 16-element block
 * (see buff_size_nvfp4_scales = elems/16 * sizeof(fp8e4m3)).
 * Column-wise (only when the output tensor carries columnwise data): MXFP8 data
 * with E8M0 scale factors, one per 32-element block (buff_size_mxfp8_scales).
 * `noop` is an optional no-op flag tensor forwarded to the kernel; `stream` is
 * the CUDA stream the kernel is launched on.
 * NOTE(review): angle-bracket contents were stripped from this file during
 * extraction — casts below appear as `reinterpret_cast(...)` without a target
 * type, kernel template argument lists are missing, and launch configs appear
 * as `<<>>`. Restore them from upstream before compiling. */
inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
  using namespace quantize_kernel;
  using namespace ptx;
  checkCuDriverContext(stream);

  // Plain cast: no fused activation.
  constexpr bool COMPUTE_ACTIVATIONS = false;
  using ParamOP = Empty;
  constexpr float (*OP)(float, const ParamOP &) = nullptr;

  NVTE_CHECK(output->has_data(), "NVFP4 Output tensor must be allocated.");
  NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");

  NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
  NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");

  // Columnwise (MXFP8) path is enabled purely by the presence of columnwise output data.
  bool use_colwise_scaling = output->has_columnwise_data();
  if (use_colwise_scaling) {
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
               "Columnwise scaling tensor must be allocated");
  }
  CheckNoopTensor(*noop, "cast_noop");

  const size_t rows = input.flat_first_dim();
  const size_t cols = input.flat_last_dim();

  // One CUDA block processes one 128x128 chunk with 128 threads.
  constexpr size_t CHUNK_DIM_Y = 128;
  constexpr size_t CHUNK_DIM_X = 128;
  constexpr size_t THREADS_PER_CHUNK = 128;

  constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;

  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
  const dim3 grid(blocks_X, blocks_Y);
  const size_t block_size = THREADS_PER_CHUNK;

  // Scale tensors are 2D; stride is their row length.
  const size_t scale_stride_rowwise = output->scale_inv.shape[1];
  const size_t scale_stride_colwise =
      use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;

  /* NOTE(review): cast target types stripped (presumably fp8e4m3* / e8m0_t*). */
  fp8e4m3 *const scales_rowwise_e4m3_ptr = reinterpret_cast(output->scale_inv.dptr);
  e8m0_t *const scales_colwise_e8m0_ptr =
      use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr;

  const ScalingType scaling_type =
      use_colwise_scaling ? ScalingType::BIDIMENSIONAL : ScalingType::ROWWISE;

  float *const amax_ptr = reinterpret_cast(output->amax.dptr);
  const float *noop_ptr = reinterpret_cast(noop->data.dptr);
  const float *const nvfp4_second_stage_scale_ptr =
      reinterpret_cast(output->scale.dptr);

  // Output data type is only required for the column-wise MXFP8 scaling.
  // It has no effect for the row-wise NVFP4 scaling, but is set to the default E4M3 for the macros to work
  const DType output_data_type =
      use_colwise_scaling ? output->columnwise_data.dtype : DType::kFloat8E4M3;

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output_data_type, OType,
          // TMA descriptors for the input and both output layouts.
          alignas(64) CUtensorMap tensor_map_input{};
          alignas(64) CUtensorMap tensor_map_output_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_colwise{};

          /* NOTE(review): BUFF_DIM_Y is not defined in this function — presumably a
           * namespace-scope constant declared above this view; confirm upstream. */
          create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
                               cols, 0, sizeof(IType) * 8);

          // Rowwise output is 4-bit packed, hence element width 4.
          create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
                               BUFF_DIM_X, cols, 0, 4);

          if (use_colwise_scaling) {
            create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
                                 BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(OType) * 8);
          }

          // Dynamic shared-memory budget: input staging + outputs + scale staging.
          constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
          constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
          constexpr size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
          // NVFP4 output: 4 bits per element -> (elems * 4) / 8 bytes.
          constexpr size_t buff_size_aligned_out_nvfp4 =
              DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);
          constexpr size_t buff_size_aligned_out_mxfp8 =
              DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
          // One E4M3 scale per 16 elements (NVFP4), one E8M0 per 32 elements (MXFP8).
          constexpr size_t buff_size_nvfp4_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 16 * sizeof(fp8e4m3);
          constexpr size_t buff_size_mxfp8_scales =
              (CHUNK_DIM_Y * CHUNK_DIM_X) / 32 * sizeof(e8m0_t);

          constexpr size_t in_mem = buff_size_aligned_in;

          const size_t out_rowwise_data_mem = buff_size_aligned_out_nvfp4;
          const size_t out_colwise_data_mem = use_colwise_scaling ? buff_size_aligned_out_mxfp8 : 0;

          const size_t out_rowwise_scales_mem = buff_size_nvfp4_scales;
          const size_t out_colwise_scales_mem = use_colwise_scaling ? buff_size_mxfp8_scales : 0;

          // Extra TMA_SHMEM_ALIGNMENT bytes leave room for in-kernel pointer alignment.
          const size_t out_mem = out_rowwise_data_mem + out_colwise_data_mem +
                                 out_rowwise_scales_mem + out_colwise_scales_mem +
                                 TMA_SHMEM_ALIGNMENT;

          const size_t dshmem_size = in_mem + out_mem;

          switch (scaling_type) {
            case ScalingType::ROWWISE: {
              /* NOTE(review): template arguments and <<<grid, block, shmem, stream>>>
               * launch configuration stripped in extraction — restore from upstream. */
              auto kernel =
                  quantize_nvfp4_kernel;
              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                   dshmem_size);

              kernel<<>>(
                  tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                  scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                  nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                  scale_stride_colwise);
              break;
            }
            case ScalingType::BIDIMENSIONAL: {
              auto kernel =
                  quantize_nvfp4_kernel;
              cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                                   dshmem_size);

              kernel<<>>(
                  tensor_map_input, tensor_map_output_rowwise, tensor_map_output_colwise,
                  scales_rowwise_e4m3_ptr, scales_colwise_e8m0_ptr, noop_ptr, amax_ptr,
                  nvfp4_second_stage_scale_ptr, rows, cols, scale_stride_rowwise,
                  scale_stride_colwise);
              break;
            }
          } NVTE_CHECK_CUDA(cudaGetLastError()););  // NOLINT(*)
  );                                                // NOLINT(*)
#else
  NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif  // FP4_TYPE_SUPPORTED
}

}  // namespace nvfp4
}  // namespace dispatch
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_CUH_
diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh
new file mode 100644
index 000000000..7322bf265
--- /dev/null
+++
b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh
@@ -0,0 +1,1287 @@
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

/*! \file quantize_transpose_nvfp4.cuh
 * \brief CUDA kernels to cast to NVFP4 and transpose.
 */

#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_

/* NOTE(review): the targets of the next four #include directives were lost when
 * angle-bracket contents were stripped from this file — restore from upstream. */
#include
#include
#include
#include

#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"

namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {

namespace quantize_transpose_kernel {

using namespace quantization_and_transposition_SF;
using namespace core;
using namespace ptx;

#if FP4_TYPE_SUPPORTED

// Tiling configuration shared by the quantize+transpose kernels below.
constexpr size_t SCALE_DIM = 16;  // NVFP4 block (x16 elts)

constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_NUM = 128;

// One scale factor per SCALE_DIM elements, in each direction.
constexpr size_t SCALES_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM;

// x2: both rowwise and columnwise scales are produced per chunk.
constexpr size_t SCALES_PER_THREAD = 2 * (CHUNK_DIM_Y * CHUNK_DIM_X) / SCALE_DIM / THREADS_NUM;

// Each call generates 4x uint32_t random numbers
constexpr size_t RNG_GENS_PER_THREAD = SCALES_PER_THREAD / 4;

constexpr size_t TILE_DIM_Y = 32;
constexpr size_t TILE_DIM_X = 128;

// Should this be SCALE_DIM or BLOCK_DIM? Both are 16, should work for both 1D and 2D
constexpr size_t SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr size_t SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM;  // 128 / 16 = 8

constexpr size_t TILES_Y = CHUNK_DIM_Y / TILE_DIM_Y;
constexpr size_t TILES_X = CHUNK_DIM_X / TILE_DIM_X;
constexpr size_t STAGES = TILES_Y * TILES_X;

// Double-buffered staging: one tile per buffer.
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = TILE_DIM_Y;
constexpr size_t BUFF_DIM_X = TILE_DIM_X;
constexpr size_t BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;

// Input buffer (BF16)
constexpr size_t BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr size_t BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;

// Output buffer (NVFP4): 4 bits per element -> (elems * 4) / 8 bytes.
constexpr size_t BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr size_t BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr size_t BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;

// Output transpose buffer (NVFP4)
constexpr size_t BUFF_OUT_T_DIM_Y = BUFF_DIM_X;
constexpr size_t BUFF_OUT_T_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr size_t BUFF_OUT_T_SIZE = BUFF_OUT_T_DIM_Y * BUFF_OUT_T_DIM_X;

// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr size_t PACK_SIZE = 8;
constexpr size_t WAVES = SCALE_DIM / PACK_SIZE;

constexpr size_t SCALING_FACTORS_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr size_t THREADS_X_ROWWISE = SCALING_FACTORS_PER_TILE_X;       // 128 / 16 = 8
constexpr size_t THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE;  // 128 / 8 = 16

constexpr size_t ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;  // 32/ 16 = 2
constexpr size_t ITERATIONS_TRANSPOSE = BUFF_IN_DIM_Y / SCALE_DIM;
constexpr size_t BUFF_OUT_IT_OFFSET = BUFF_OUT_T_DIM_X / ITERATIONS_TRANSPOSE;

static_assert(BUFF_DIM_Y >= SCALE_DIM &&
              "Number of buffer rows must be greater or equal to the size of the columwise "
              "scaling block\0");
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
              "Number of buffer rows must be greater or equal to the number of rowwise "
              "processing threads in Y dimension\0");

// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4;  // 256

// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM;  // 8 = 128 / 16

/* Quantizes one 128x128 chunk to NVFP4 (rowwise, 16-element scale blocks with
 * nvfp4_scale_t scale factors) and, when RETURN_TRANSPOSE is set, additionally
 * emits the NVFP4 transpose with columnwise scale factors. Input tiles are
 * streamed through a double-buffered TMA pipeline over STAGES stages; outputs
 * are staged in shared memory and written back with TMA bulk-tensor copies.
 * Stochastic rounding draws bits from a Philox RNG seeded by `rng_state` when
 * USE_STOCHASTIC_ROUNDING is enabled; `noop[0] == 1.0f` makes the kernel an
 * early-exit no-op (non-activation path).
 * NOTE(review): the template parameter list was stripped in extraction; the body
 * references COMPUTE_ACTIVATIONS, ParamOP/OP, IType, USE_STOCHASTIC_ROUNDING and
 * RETURN_TRANSPOSE — restore the exact list from upstream. Likewise,
 * `std::is_same_v`, `ptx::FPx2`, `Vec`, `static_cast(...)`,
 * `reinterpret_cast(...)`, `mul_cvt_*_to_fp4_4x` and
 * `compute_decoding_scaling_factor` are missing their angle-bracket arguments. */
template
__global__ void __launch_bounds__(THREADS_NUM)
    quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                                    const __grid_constant__ CUtensorMap tensor_map_output,
                                    const __grid_constant__ CUtensorMap tensor_map_output_t,
                                    nvfp4_scale_t *const scales_ptr,
                                    nvfp4_scale_t *const scales_t_ptr, const float *noop,
                                    const float *const amax_rowwise_ptr,
                                    const float *const amax_colwise_ptr, const size_t rows,
                                    const size_t cols, const size_t scale_stride,
                                    const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v);

  using IType2 = typename ptx::FPx2;

  // Early-exit no-op: only valid when no activation is fused.
  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  // Unique RNG sequence per thread across the whole grid.
  const size_t rng_sequence =
      threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
  rng.init(rng_seed, rng_sequence, rng_offset);
  uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
  // Index of the random number. It increments each time when used and resets to 0 if reaches 4x
  int rnd_idx = 0;

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS;

  // Chunk origin in the input; _t variants swap X/Y for the transposed output.
  const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;

  const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
  const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y;

  const size_t chunk_rows = rows - block_offset_Y;

  const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X;
  const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X;
  const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y;

  // Rowwise: threads form a THREADS_Y_ROWWISE x THREADS_X_ROWWISE grid;
  // colwise/transpose: one thread per column.
  const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const size_t tid_X_colwise = threadIdx.x;
  const size_t tid_Y_t = tid_X_colwise;
  // const size_t tid_X_t = 0;

  const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
  const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM;
  const size_t thread_offset_X_colwise = tid_X_colwise;

  const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const size_t row_base_colwise = block_offset_Y;
  const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t;
  const size_t scales_offset_X_t = scales_block_offset_X_t;

  const size_t SFs_per_row = cols / SCALE_DIM;

  const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row;
  const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols;

  // Helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;

  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT);

  constexpr size_t in_mem = buff_size_aligned_in;

  constexpr size_t out_mem_rowwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_colwise_data = buff_size_aligned_out;
  constexpr size_t out_mem_rowwise_scales = 0;

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast(TMA_SHMEM_ALIGNMENT - 1));

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast(dshmem);
  fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem);
  fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data);

  nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
  nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast(
      dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Compute a global encoding/decoding scaling factors for all S_dec_b
  const float S_enc_rowwise = (amax_rowwise_ptr == nullptr)
                                  ? 1.0f
                                  : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
  // NOTE: This is to match with how emulation code was written.
  const float S_dec_rowwise = 1.0 / S_enc_rowwise;

  const float S_enc_colwise = (amax_colwise_ptr == nullptr)
                                  ? S_enc_rowwise
                                  : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
  const float S_dec_colwise = 1.0 / S_enc_colwise;

  float thread_amax = 0.0f;

// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers(mbar, is_master_thread);

  // Prime the pipeline with the first tile.
  copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                    &mbar[0], is_master_thread);

#pragma unroll
  for (size_t stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    const size_t buff_offset_in = buff * BUFF_IN_SIZE;
    const size_t buff_offset_out = buff * BUFF_OUT_SIZE;
    const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      // Prefetch the next tile into the other buffer while this one is processed.
      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_IN_SIZE;

      copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                        global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], 0);

    float block_amax = 0.0f;

    // COLWISE scaling
    if constexpr (RETURN_TRANSPOSE) {
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) {
        const size_t in_thread_offset_Y = 0 + it * SCALE_DIM;
        const size_t in_thread_offset_X = thread_offset_X_colwise;

        const size_t out_t_thread_offset_Y = thread_offset_X_colwise;
        const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET;

        const size_t shmem_offset_base_colwise_in =
            buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X;
        const size_t shmem_offset_base_colwise_out_t =
            buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X;

        block_amax = 0.0f;
        float in_compute_colwise[SCALE_DIM];
        IType in_colwise_IType[SCALE_DIM];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          // Pure-cast fast path: keep the amax reduction in half precision.
          IType block_amax_f16 = static_cast(0.0f);
#pragma unroll
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            in_colwise_IType[i] = in_sh[shmem_offset_colwise];
            block_amax_f16 = __hmax(block_amax_f16, __habs(in_colwise_IType[i]));
          }
          block_amax = static_cast(block_amax_f16);
        } else {
#pragma unroll
          for (int i = 0; i < SCALE_DIM; ++i) {
            const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X;
            float elt = static_cast(in_sh[shmem_offset_colwise]);
            if constexpr (COMPUTE_ACTIVATIONS) {
              elt = OP(elt, {});
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v) {
              elt = static_cast(static_cast(elt));
            }
            // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
            if constexpr (IS_CACHED_ACT_OP) {
              cached_act_sh[shmem_offset_colwise] = static_cast(elt);
            }
            if constexpr (COMPUTE_ACTIVATIONS) {
              const bool row_out_of_bounds_colwise =
                  (row_base_colwise + stage_offset_Y + i >= rows);
              const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
              if (!out_of_bounds) {
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
            } else {
              // If no activation, elt is 0 so we can safely do this
              block_amax = fmaxf(block_amax, fabsf(elt));
            }
            in_compute_colwise[i] = elt;
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_colwise);

        // Store scaling factors through SHMEM
        const size_t scale_idx_sh =
            tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it;
        out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8;

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

        // 3. Scale elements
        fp4e2m1x4 regs[SCALE_DIM / 4];

#pragma unroll
        for (int e = 0; e < SCALE_DIM / 4; ++e) {
          const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
          if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
            const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]);
            regs[e] = ptx::mul_cvt_bf16_to_fp4_4x(
                elts, block_scale_inverse_2x, rbits);
          } else {
            const float2 in01 = *reinterpret_cast(&in_compute_colwise[4 * e]);
            const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]);
            regs[e] = ptx::mul_cvt_fp32_to_fp4_4x(
                in01, in23, block_scale_inverse_2x, rbits);
          }
        }

        const int group = thread_lane / 16;
        uint32_t val[2];
        uint32_t *regs_4x = reinterpret_cast(regs);

        // Helps reducing bank conflicts
        switch (group) {
          case 0:
            val[0] = regs_4x[0];
            val[1] = regs_4x[1];
            break;
          case 1:
            val[0] = regs_4x[1];
            val[1] = regs_4x[0];

            break;
        }
        uint32_t *out_t_data_sh_as_uint32_t =
            reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]);
        out_t_data_sh_as_uint32_t[group] = val[0];            // idx1 = (group + 0) % 2;
        out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1];  // idx2 = (group + 1) % 2;
      }
    }

    // ROWWISE scaling
    {
      const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y;
#pragma unroll
      for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) {
        const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;

        const size_t shmem_offset_base_rowwise_in =
            buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X;
        const size_t shmem_offset_base_rowwise_out =
            buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X;

        const size_t it_offset_Y = stage_offset_Y + it * THREADS_Y_ROWWISE;

        block_amax = 0.0f;
        float in_compute_rowwise[SCALE_DIM];
        Vec in_cached[WAVES];

        // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY
        Vec in_IType[WAVES];

        // 1. Read/Compute elements. Find NVFP4-block AMAX
        if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
          IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;
            // Load elements
            in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE / 2; ++e) {
              ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
            }
          }
          block_amax =
              static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        } else if constexpr (IS_CACHED_ACT_OP) {
          // ensures that all writes to cache made in the section above are visible to all threads
          __syncthreads();
          IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)};
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
            const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
            const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

            // Load cached elements
            in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
            // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
            // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
            if (!out_of_bounds) {
              if constexpr (std::is_same_v) {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; ++e) {
                  block_amax = fmaxf(block_amax, fabsf(in_cached[w].data.elt[e]));
                }
              } else {
#pragma unroll
                for (int e = 0; e < PACK_SIZE; e += 2) {
                  const IType2 in_cached_2x = {in_cached[w].data.elt[e],
                                               in_cached[w].data.elt[e + 1]};
                  ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
                }
              }
            }
          }
          if constexpr (!std::is_same_v) {
            block_amax =
                static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
          }
        } else {
#pragma unroll
          for (int w = 0; w < WAVES; ++w) {
            const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
            const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
            const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx;

            Vec in;
            // NOTE(review): `act_in` appears to be declared but never used in this
            // branch — confirm against upstream.
            Vec act_in;

            in.load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
            for (int e = 0; e < PACK_SIZE; ++e) {
              const size_t j = w * PACK_SIZE + e;
              // Compute element
              float elt = static_cast(in.data.elt[e]);
              if constexpr (COMPUTE_ACTIVATIONS) {
                elt = OP(elt, {});
              }
              // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
              if constexpr (!std::is_same_v) {
                elt = static_cast(static_cast(elt));
              }
              if constexpr (COMPUTE_ACTIVATIONS) {
                const bool row_out_of_bounds_rowwise = (row_base_rowwise + it_offset_Y >= rows);
                const bool swizzled_col_out_of_bounds =
                    (block_offset_X + swizzled_thread_idx >= cols);
                const bool out_of_bounds =
                    (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
                if (!out_of_bounds) {
                  block_amax = fmaxf(block_amax, fabsf(elt));
                }
              } else {
                // If no activation, elt is 0 so we can safely do this
                block_amax = fmaxf(block_amax, fabsf(elt));
              }
              in_compute_rowwise[j] = elt;
            }
          }
        }

        // 2. Compute E4M3 scaling factor
        const nvfp4_scale_t S_dec_b_fp8 =
            compute_decoding_scaling_factor(block_amax, S_enc_rowwise);

        // Check boundaries
        const size_t scales_offset_Y =
            scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE;
        const size_t scales_offset_X = scales_offset_X_rowwise;
        const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X;

        // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows;
        const bool rowwise_scale_is_within_bounds_Y =
            (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows;
        if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) {
          scales_ptr[scale_idx_global] = S_dec_b_fp8;
        }

        // Compute "correct" per-block encoding scaling factor
        constexpr float float_max = detail::TypeExtrema::max;
        const float block_scale_inverse = fminf(
            1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max);  // S_enc_b_fp8
        const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse};

// 3. Scale elements
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          Vec out;
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 4; ++e) {
            const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx);
            IType2 in01;
            IType2 in23;
            if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) {
              const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]);
              out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x(
                  elts, block_scale_inverse_2x, rbits);
            } else if constexpr (IS_CACHED_ACT_OP) {
              const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]);
              out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x(
                  elts, block_scale_inverse_2x, rbits);
            } else {
              const int j = w * PACK_SIZE + 4 * e;
              const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]);
              const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]);
              out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x(
                  in01, in23, block_scale_inverse_2x, rbits);
            }
          }
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM;
          const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
          // Packed 4-bit output: two elements per byte, hence / 2.
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2;
          out.store_to(&out_data_sh[shmem_offset_rowwise]);
        }
      }
    }

    __builtin_assume(thread_amax >= 0);
    thread_amax = fmaxf(thread_amax, block_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
      const size_t global_offset_X = block_offset_X;

      const size_t global_offset_Y_t = block_offset_Y_t;
      const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y;

      ptx::cp_async_bulk_tensor_2d_shared_to_global(
          reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y,
          reinterpret_cast(&out_data_sh[buff_offset_out]));

      if constexpr (RETURN_TRANSPOSE) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast(&tensor_map_output_t), global_offset_X_t,
            global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }  // end of stages

  // Vectorized store scaling factors through SHMEM
  if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) {
    using ScalesVec = Vec;
    const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y;
    ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]);
    const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t;
    const size_t count =  // number of scales in Y dimension of this chunk
        (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM);
    nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global];
    constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t);
    if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) {
      // Fast path: vectorized store when destination is properly aligned
      scales_vec.store_to(dst);
    } else {
      // Safe path: element-wise store for tails or unaligned destinations
      scales_vec.store_to_elts(dst, 0, count);
    }
  }

  destroy_barriers(mbar, is_master_thread);
#else
  NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif  // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}

/* 2D-block-scaled variant — definition continues past the end of this view;
 * head reproduced verbatim (same stripped-template caveats as above). */
template
__global__ void __launch_bounds__(THREADS_NUM)
    quantize_transpose_nvfp4_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                                       const __grid_constant__ CUtensorMap tensor_map_output,
                                       const __grid_constant__ CUtensorMap tensor_map_output_t,
                                       nvfp4_scale_t *const scales_ptr,
                                       nvfp4_scale_t *const scales_t_ptr, const float *noop,
                                       const float *const amax_rowwise_ptr,
                                       const float *const amax_colwise_ptr, const size_t rows,
                                       const size_t cols, const size_t scale_stride,
                                       const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool NO_ACTIVATIONS_NOT_FP32_INPUT =
      (!COMPUTE_ACTIVATIONS) && (!std::is_same_v);

  using IType2 = typename ptx::FPx2;

  if constexpr (!COMPUTE_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }
  const size_t rng_sequence =
      threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
  const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
  const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
  transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
  rng.init(rng_seed, rng_sequence, rng_offset);
  uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ?
rng.generate4() : uint4{0, 0, 0, 0}; + int rnd_idx = + 0; // Index of the random number. It increments each time when used and resets to 0 if reaches 4x + + // NEW: 2D Block-based scaling constants + constexpr size_t BLOCK_DIM = 16; + constexpr size_t BLOCKS_PER_TILE_Y = TILE_DIM_Y / BLOCK_DIM; // 32/16 = 2 + constexpr size_t BLOCKS_PER_TILE_X = TILE_DIM_X / BLOCK_DIM; // 128/16 = 8 + constexpr size_t ITERATIONS_BLOCK = 2; // iterations to calculate 2d block amaxes of 1 tile + constexpr size_t BLOCKS_PER_WARP = BLOCKS_PER_TILE_X / (THREADS_NUM / 32); // 8 / (128/32) = 2 + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS; + + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + + const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + + const size_t chunk_rows = rows - block_offset_Y; + + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_X_colwise = threadIdx.x; + const size_t tid_Y_t = tid_X_colwise; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + const size_t scales_offset_X_t = scales_block_offset_X_t; + + const size_t SFs_per_row = cols / SCALE_DIM; + + const bool 
rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + + // Helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_mem_rowwise_data = buff_size_aligned_out; + constexpr size_t out_mem_colwise_data = buff_size_aligned_out; + constexpr size_t out_mem_rowwise_scales = 0; + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + fp4e2m1x2 *out_data_sh = reinterpret_cast(dshmem + in_mem); + fp4e2m1x2 *out_t_data_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise_data); + + nvfp4_scale_t *out_rowwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data); + nvfp4_scale_t *out_colwise_scales_sh = reinterpret_cast( + dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); + + // Compute a global encoding/decoding scaling factors for all S_dec_b + const float S_enc_rowwise = (amax_rowwise_ptr == nullptr) + ? 1.0f + : compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr); + // NOTE: This is to match with how emulation code was written. + const float S_dec_rowwise = 1.0 / S_enc_rowwise; + + const float S_enc_colwise = (amax_colwise_ptr == nullptr) + ? S_enc_rowwise + : compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr); + const float S_dec_colwise = 1.0 / S_enc_colwise; + + const size_t warp_id = threadIdx.x / 32; + const size_t lane_id = threadIdx.x % 32; + float thread_amax = 0.0f; + const size_t block_in_warp = lane_id / BLOCKS_PER_WARP; + +// Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + __shared__ alignas(8) uint64_t mbar[STAGES]; + + __shared__ __align__(16) float block_amax_matrix[BLOCKS_PER_TILE_Y][BLOCKS_PER_TILE_X + 1]; + + // Helper function for warp reduction + auto warp_reduce_amax = [](float thread_amax, int block_in_warp) -> float { +#pragma unroll + for (int delta = 8; delta >= 1; delta /= 2) { + float other_amax = __shfl_xor_sync(0xffffffff, thread_amax, delta); + thread_amax = fmaxf(thread_amax, other_amax); + } + return thread_amax; + }; + + initialize_barriers(mbar, is_master_thread); + + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; + + const size_t buff_offset_in = buff * BUFF_IN_SIZE; + const size_t buff_offset_out = buff * BUFF_OUT_SIZE; + const size_t buff_offset_out_t = buff * BUFF_OUT_T_SIZE; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_IN_SIZE; + + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + float block_amax = 0.0f; + +#pragma unroll + for (size_t block_iter = 0; block_iter < ITERATIONS_BLOCK; ++block_iter) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + const size_t block_in_tile_y = block_iter; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + for (int elem = 0; elem < BLOCK_DIM; elem += 2) { + const size_t elem_0_row = block_iter * BLOCK_DIM + elem; + const size_t elem_1_row = elem_0_row + 1; + const size_t elem_0_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + const size_t elem_1_col = elem_0_col; + + const size_t shmem_offset_0 = buff_offset_in + elem_0_row * BUFF_IN_DIM_X + elem_0_col; + const size_t shmem_offset_1 = buff_offset_in + elem_1_row * BUFF_IN_DIM_X + elem_1_col; + + IType2 val_2x; + val_2x.x = in_sh[shmem_offset_0]; + val_2x.y = in_sh[shmem_offset_1]; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, val_2x); + } + + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else { + for (int elem = 0; elem < BLOCK_DIM; ++elem) { + const size_t elem_row = block_iter * BLOCK_DIM + elem; + const size_t elem_col = warp_id * BLOCKS_PER_WARP * BLOCK_DIM + lane_id; + + // Bounds checking + const bool row_out_of_bounds = (block_offset_Y + stage_offset_Y + 
elem_row >= rows); + const bool col_out_of_bounds = (block_offset_X + elem_col >= cols); + if (!row_out_of_bounds && !col_out_of_bounds) { + const size_t shmem_offset = buff_offset_in + elem_row * BUFF_IN_DIM_X + elem_col; + float elt = static_cast(in_sh[shmem_offset]); + + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset] = static_cast(elt); + } + + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } + } + // Warp reduction to get block amax + block_amax = warp_reduce_amax(thread_amax, block_in_warp); + + if (lane_id == 0 || lane_id == 16) { + block_amax_matrix[block_in_tile_y][block_in_tile_x] = block_amax; + } + } + + // sync thread to ensure block_amax_matrix is done storing + __syncthreads(); + + // COLWISE scaling + if constexpr (RETURN_TRANSPOSE) { +#pragma unroll + for (size_t it = 0; it < ITERATIONS_TRANSPOSE; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = threadIdx.x / BLOCK_DIM; + + const size_t in_thread_offset_Y = 0 + it * SCALE_DIM; + const size_t in_thread_offset_X = thread_offset_X_colwise; + + const size_t out_t_thread_offset_Y = thread_offset_X_colwise; + const size_t out_t_thread_offset_X = 0 + it * BUFF_OUT_IT_OFFSET; + + const size_t shmem_offset_base_colwise_in = + buff_offset_in + in_thread_offset_Y * BUFF_IN_DIM_X + in_thread_offset_X; + const size_t shmem_offset_base_colwise_out_t = + buff_offset_out_t + out_t_thread_offset_Y * BUFF_OUT_T_DIM_X + out_t_thread_offset_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_colwise[SCALE_DIM]; + IType in_colwise_IType[SCALE_DIM]; + // 3. 
Scale elements + + // Load data in + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { +#pragma unroll + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + } + } else { + for (int i = 0; i < SCALE_DIM; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise_in + i * BUFF_IN_DIM_X; + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } + + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_colwise); + + // // Store scaling factors through SHMEM + const size_t scale_idx_sh = + tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; + out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_colwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + fp4e2m1x4 regs[SCALE_DIM / 4]; +#pragma unroll + for (int e = 0; e < SCALE_DIM / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_colwise_IType[4 * e]); + regs[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const float2 in01 = 
*reinterpret_cast(&in_compute_colwise[4 * e]); + const float2 in23 = *reinterpret_cast(&in_compute_colwise[4 * e + 2]); + regs[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const int group = thread_lane / 16; + uint32_t val[2]; + uint32_t *regs_4x = reinterpret_cast(regs); + + // Helps reducing bank conflicts + switch (group) { + case 0: + val[0] = regs_4x[0]; + val[1] = regs_4x[1]; + break; + case 1: + val[0] = regs_4x[1]; + val[1] = regs_4x[0]; + break; + } + uint32_t *out_t_data_sh_as_uint32_t = + reinterpret_cast(&out_t_data_sh[shmem_offset_base_colwise_out_t]); + out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; + out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; + } + } + + // ROWWISE scaling + { + const size_t stage_rowwise_scales_offset_Y = stage * BUFF_DIM_Y; +#pragma unroll + for (size_t it = 0; it < ITERATIONS_NORMAL; ++it) { + const size_t block_in_tile_y = it; + const size_t block_in_tile_x = tid_X_rowwise; + const size_t it_thread_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE; + + const size_t shmem_offset_base_rowwise_in = + buff_offset_in + it_thread_offset_Y_rowwise * BUFF_IN_DIM_X; + const size_t shmem_offset_base_rowwise_out = + buff_offset_out + it_thread_offset_Y_rowwise * BUFF_OUT_DIM_X; + + block_amax = block_amax_matrix[block_in_tile_y][block_in_tile_x]; + float in_compute_rowwise[SCALE_DIM]; + Vec in_cached[WAVES]; + + // used as an IType container for BF16/FP16 --> NVFP4 CAST ONLY + Vec in_IType[WAVES]; + + // 1. Read/Compute elements. 
Find NVFP4-block AMAX + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); + } + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_in + swizzled_thread_idx; + + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const size_t j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); + if constexpr (COMPUTE_ACTIVATIONS) { + elt = OP(elt, {}); + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + in_compute_rowwise[j] = elt; + } + } + } + + // 2. 
Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } + + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + const float block_scale_inverse = fminf( + 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 + const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; + + // 3. 
Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 4; ++e) { + const uint32_t rbits = get_rbits(rng, random_uint4, rnd_idx); + IType2 in01; + IType2 in23; + if constexpr (NO_ACTIVATIONS_NOT_FP32_INPUT) { + const uint64_t elts = *reinterpret_cast(&in_IType[w].data.elt[2 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else if constexpr (IS_CACHED_ACT_OP) { + const uint64_t elts = *reinterpret_cast(&in_cached[w].data.elt[4 * e]); + out.data.elt[e] = ptx::mul_cvt_bf16_to_fp4_4x( + elts, block_scale_inverse_2x, rbits); + } else { + const int j = w * PACK_SIZE + 4 * e; + const float2 in01 = make_float2(in_compute_rowwise[j], in_compute_rowwise[j + 1]); + const float2 in23 = make_float2(in_compute_rowwise[j + 2], in_compute_rowwise[j + 3]); + out.data.elt[e] = ptx::mul_cvt_fp32_to_fp4_4x( + in01, in23, block_scale_inverse_2x, rbits); + } + } + + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise_out + swizzled_idx / 2; + out.store_to(&out_data_sh[shmem_offset_rowwise]); + } + } + } + + __builtin_assume(thread_amax >= 0); + thread_amax = fmaxf(thread_amax, block_amax); + + // Wait for shared memory writes to be visible to TMA engine. + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. 
+ + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + + const size_t global_offset_Y_t = block_offset_Y_t; + const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, + reinterpret_cast(&out_data_sh[buff_offset_out])); + + if constexpr (RETURN_TRANSPOSE) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + } + + // Create a "bulk async-group" out of the previous bulk copy operation. + ptx::cp_async_bulk_commit_group(); + } + } // end of stages + + // Vectorized store scaling factors through SHMEM + if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + using ScalesVec = Vec; + const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + const size_t count = // number of scales in Y dimension of this chunk + (chunk_rows >= CHUNK_DIM_Y) ? 
SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // Fast path: vectorized store when destination is properly aligned + scales_vec.store_to(dst); + } else { + // Safe path: element-wise store for tails or unaligned destinations + scales_vec.store_to_elts(dst, 0, count); + } + } + + destroy_barriers(mbar, is_master_thread); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} +#endif // FP4_TYPE_SUPPORTED +} // namespace quantize_transpose_kernel + +template +void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_transpose_kernel; + using namespace ptx; + bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + + // If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to + // return the transposed data. + // TODO(Frank): Is there a better way to do this?
+ bool return_transpose = output->has_columnwise_data(); + + constexpr bool COMPUTE_ACTIVATIONS = false; + using ParamOP = Empty; + constexpr float (*OP)(float, const ParamOP &) = nullptr; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); + NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + if (return_transpose) { + NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated"); + } + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + + NVTE_CHECK(rows % 32 == 0, + "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA + NVTE_CHECK(cols % 32 == 0, + "Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + const dim3 grid(blocks_X, blocks_Y); + const size_t block_size = THREADS_NUM; + + const size_t scale_stride = output->scale_inv.shape[1]; + const size_t scale_stride_transpose = + return_transpose ? 
output->columnwise_scale_inv.shape[1] : 0; + + nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); + nvfp4_scale_t *const scales_transpose_ptr = + reinterpret_cast(output->columnwise_scale_inv.dptr); + + const float *noop_ptr = reinterpret_cast(noop->data.dptr); + const float *const amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const float *const amax_colwise_ptr = + reinterpret_cast(output->columnwise_amax.dptr); + + const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr; + const size_t *rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + using IType = bf16; + + alignas(64) CUtensorMap tensor_map_input{}; + alignas(64) CUtensorMap tensor_map_output{}; + alignas(64) CUtensorMap tensor_map_output_transpose{}; + + create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + sizeof(IType) * 8); + + create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, + 4); + if (return_transpose) { + create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + } + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE((buff_elems_total * 4) / 8, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_scales = (CHUNK_DIM_Y * 
CHUNK_DIM_X) / 16 * sizeof(nvfp4_scale_t); + + constexpr size_t in_mem = buff_size_aligned_in; + + constexpr size_t out_data_mem = buff_size_aligned_out; + constexpr size_t out_data_transpose_mem = buff_size_aligned_out; + constexpr size_t out_scales_transpose_mem = buff_size_scales; + + constexpr size_t out_mem = out_data_mem + out_data_transpose_mem; + + constexpr size_t dshmem_size = in_mem + out_mem + out_scales_transpose_mem + TMA_SHMEM_ALIGNMENT; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; + + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_CUH_ diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index ec29e6e12..56369db27 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -64,6 +64,15 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl #endif _comm_created = true; } + + initialize(tp_size, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, + num_comm_sm, set_sm_margin, use_ce, atomic_gemm); +} + +void 
CommOverlapCore::initialize(int tp_size, int num_splits, int num_max_streams, + int comm_cga_size, int gemm_priority, int comm_priority, + int num_comm_sm, bool set_sm_margin, bool use_ce, + bool atomic_gemm) { _use_ce = static_cast(use_ce); _num_comm_sm = num_comm_sm; _cga_size = comm_cga_size; @@ -278,6 +287,11 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType allgather_handle, barrier_handle, num_splits, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, false, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, rs_overlap_first_gemm); +} + +void CommOverlapBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm) { _rs_overlap_first_gemm = rs_overlap_first_gemm; _rs_kernel_type = getenv("NVTE_RS_STRIDED_ATOMIC", 0); NVTE_CHECK(_rs_kernel_type >= 0 && _rs_kernel_type <= 3, @@ -288,7 +302,9 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_ub_comm->myrank == 0) printf("!!! [UB] Register UBuf %d\n", _ub_reg); + if (_ub_comm->myrank == 0) { + printf("!!! 
[UB] Register UBuf %d\n", _ub_reg); + } _ubuf = TensorWrapper(buffer_ptr, buffer_shape, buffer_dtype); NVTE_CHECK_CUDA( @@ -640,6 +656,11 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, allgather_handle, barrier_handle, tp_size, num_max_streams, comm_cga_size, gemm_priority, comm_priority, num_comm_sm, set_sm_margin, use_ce, atomic_gemm) { + initialize(buffer_shape, buffer_dtype, comm_type, aggregate); +} + +void CommOverlapP2PBase::initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate) { _is_p2p = true; _is_reduce_scatter = comm_type == CommOverlapType::RS; _aggregate = aggregate; @@ -647,28 +668,28 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, // Create workspace tensor with userbuffer NVTE_CHECK(buffer_shape.size() == 2, "Userbuffer shape must be 2-dimensional!"); size_t buffer_bytes = get_buffer_size_bytes(buffer_shape[0], buffer_shape[1], buffer_dtype); - int buffer_chunk_bytes = buffer_bytes / tp_size; - _num_ubuf_chunks = tp_size; + int buffer_chunk_bytes = buffer_bytes / _tp_size; + _num_ubuf_chunks = _tp_size; if (_is_reduce_scatter) { // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk // outputs for reduction at the end of the pipelining. - buffer_bytes = buffer_bytes / tp_size * (tp_size * 2 - 1); - _num_ubuf_chunks = tp_size * 2 - 1; + buffer_bytes = buffer_bytes / _tp_size * (_tp_size * 2 - 1); + _num_ubuf_chunks = _tp_size * 2 - 1; } void *buffer_ptr; _ub_reg = register_user_buffer_collective(&buffer_ptr, buffer_bytes, _ub_comm, true); - if (_rank == 0) printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); + if (_rank == 0) printf("!!! 
[UBP2P] UBuf %d\n", _ub_reg); _ubuf = TensorWrapper( buffer_ptr, - std::vector{buffer_shape[0] / tp_size * _num_ubuf_chunks, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size * _num_ubuf_chunks, buffer_shape[1]}, buffer_dtype); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(buffer_ptr); for (int i = 0; i < _num_ubuf_chunks; i++) { _ubufs.push_back(TensorWrapper(reinterpret_cast(ubuf_byte_ptr), - std::vector{buffer_shape[0] / tp_size, buffer_shape[1]}, + std::vector{buffer_shape[0] / _tp_size, buffer_shape[1]}, buffer_dtype)); ubuf_byte_ptr += buffer_chunk_bytes; } @@ -691,7 +712,7 @@ CommOverlapP2PBase::CommOverlapP2PBase(const std::vector &buffer_shape, NVTE_CHECK_CUDA(cudaMemset(_counter.dptr(), 0, sizeof(int32_t))); } - for (int i = 0; i < std::min(num_max_streams, _tp_size); i++) { + for (int i = 0; i < _stream_compute.size(); i++) { cudaStream_t stream; NVTE_CHECK_CUDA(cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, _comm_priority)); _stream_send.push_back(std::move(stream)); @@ -711,6 +732,38 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { } } +void CommOverlapP2PBase::copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, + bool local_chunk, bool rowwise) { + // Check element size + const size_t element_size = source.element_size(); + NVTE_CHECK(_ubuf.element_size() == element_size, + "Tried to copy data into a Userbuffers buffer but dtypes are not compatible ", + "(source dtype has ", element_size, " bytes, UB dtype has ", _ubuf.element_size(), + " bytes)"); + + // Input data + const size_t source_size = source.numel(); + const void *src_ptr = (rowwise) ? 
source.dptr() : source.columnwise_dptr(); + + // Userbuffers data + void *dst_ptr; + if (local_chunk) { + NVTE_CHECK(_ubufs[_tp_id].numel() == source_size, + "Tried to copy an invalid tensor into a local chunk of a Userbuffers buffer ", + "(source_size=", source_size, ", local_ubuf_size=", _ubufs[_tp_id].numel(), ")"); + dst_ptr = _ubufs[_tp_id].dptr(); + } else { + NVTE_CHECK(_ubuf.numel() == source_size, + "Tried to copy an invalid tensor into a Userbuffers buffer ", + "(source_size=", source_size, ", ubuf_size=", _ubuf.numel(), ")"); + dst_ptr = _ubuf.dptr(); + } + + // Copy data + NVTE_CHECK_CUDA(cudaMemcpyAsync(dst_ptr, src_ptr, source_size * element_size, + cudaMemcpyDeviceToDevice, stream)); +} + TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, size_t chunk_id) { // Start with a chunk of the source tensor @@ -851,6 +904,15 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, const bool do_gelu = pre_gelu_out.numel() > 0; size_t workspace_size_chunk = workspace.numel() / _stream_compute.size(); + // Check B copy sizing + if (B_copy.numel() > 0) { + NVTE_CHECK(B_copy.numel() == _ubuf.numel(), "Expected all-gathered B copy buffer with ", + _ubuf.numel(), " elements but got ", B_copy.numel()); + NVTE_CHECK(B_copy.element_size() == _ubuf.element_size(), + "Expected all-gathered B copy buffer with ", _ubuf.element_size() * 8, + "-bit data type but got ", B_copy.element_size() * 8, "-bit"); + } + NVTE_CHECK_CUDA(cudaEventRecord(_start_compute, stream_main)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _start_compute, 0)); NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_recv, _start_compute, 0)); @@ -919,12 +981,6 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 
0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } else { @@ -972,16 +1028,16 @@ void CommOverlapP2PBase::split_overlap_ag(const TensorWrapper &A, bool transa, NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_send[0], _stop_recv, 0)); NVTE_CHECK_CUDA( cudaStreamWaitEvent(_stream_compute[(i + 1) % _stream_compute.size()], _stop_recv, 0)); - } else if (B_copy.numel() > 0) { - assert(B_copy.numel() == _ubufs[_tp_id].numel()); - assert(B_copy.element_size() == _ubufs[_tp_id].element_size()); - NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubufs[_tp_id].dptr(), - _ubufs[_tp_id].bytes(), cudaMemcpyDeviceToDevice, - _stream_send[0])); } } } + // Copy all-gathered B from communication buffer into auxiliary output + if (B_copy.numel() > 0) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(B_copy.dptr(), _ubuf.dptr(), _ubuf.bytes(), + cudaMemcpyDeviceToDevice, _stream_send[0])); + } + _ub_comm->sms = ori_sms; for (size_t i = 0; i < _stream_compute.size(); i++) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_compute, _stream_compute[i])); diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 1ce89c512..6c7bed55a 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -670,9 +670,36 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * reinterpret_cast(&memhndl), sizeof(cudaIpcMemHandle_t), comm->comm_intra); + // Check for NVLINK support before attempting IPC operations + if (comm->nvsize > 1) { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + cudaDeviceProp deviceProp; + 
NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, current_device)); + bool peer_access_available = false; + for (int i = 0; i < comm->nvsize; i++) { + if (i != comm->nvrank) { + int can_access_peer; + cudaError_t peer_result = cudaDeviceCanAccessPeer(&can_access_peer, current_device, i); + if (peer_result == cudaSuccess && can_access_peer) { + peer_access_available = true; + break; + } + } + } + if (!peer_access_available) { + free(tmp); + NVTE_ERROR( + "No peer-to-peer access available between GPUs. This platform does not support the " + "GPU-to-GPU " + "communication required for multi-GPU userbuffers. Consider using single-GPU mode."); + return 1; + } + } + for (int i = 0; i < comm->nvsize; i++) { if (i != comm->nvrank) { - NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], // NOLINT(*) + NVTE_CHECK_CUDA(cudaIpcOpenMemHandle(&(comm->peer_ptr[hndl][i]), tmp[i], cudaIpcMemLazyEnablePeerAccess)); } } @@ -693,4 +720,5 @@ int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator * comm->mem_ptr[hndl] = *gpubuff; + printf("***** Returning *****\n"); return comm->free_region++; } diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index e67694c38..a0193e95f 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information.
@@ -42,6 +42,10 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { return CUDA_R_8F_E4M3; case DType::kFloat8E5M2: return CUDA_R_8F_E5M2; +#if CUDA_VERSION >= 12080 + case DType::kFloat4E2M1: + return CUDA_R_4F_E2M1; +#endif default: NVTE_ERROR("Invalid type"); } @@ -165,7 +169,9 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits) { + const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { @@ -174,6 +180,8 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, }(); // rank is the number of dimensions of the array constexpr uint32_t rank = 2; + + // Dimension for the packed data types must reflect the number of individual U# values. uint64_t size[rank] = {globalX, globalY}; // The stride is the number of bytes to traverse from the first element of one row to the next @@ -212,7 +220,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE, // Swizzling can be used to avoid shared memory bank conflicts. - CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE, + swizzle, // L2 Promotion can be used to widen the effect of a cache-policy to a wider // set of L2 cache lines. @@ -222,13 +230,13 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, // Any element that is outside of bounds will be set to zero by the TMA transfer. 
CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); } -#endif //#ifndef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100() { int deviceComputeCapability = cuda::sm_arch(cuda::current_device()); return deviceComputeCapability >= 100; } +#endif //#ifndef __HIP_PLATFORM_AMD__ std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size) { diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index ce510334b..5feeb600c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -54,8 +54,14 @@ inline bool is_delayed_tensor_scaling(const NVTEScalingMode &mode) { return mode == NVTE_DELAYED_TENSOR_SCALING; } +inline bool is_nvfp4_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + +inline bool is_mxfp8_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } + inline bool is_mxfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_MXFP8_1D_SCALING; } +inline bool is_nvfp_scaling(const NVTEScalingMode &mode) { return mode == NVTE_NVFP4_1D_SCALING; } + inline size_t product(const std::vector &shape, const size_t begin, const size_t end) { NVTE_CHECK(begin <= end && end <= shape.size(), "Attempted to access entries ", begin, " to ", end, " in a vector with ", shape.size(), " entries"); @@ -114,6 +120,7 @@ struct Tensor { SimpleTensor data; SimpleTensor columnwise_data; SimpleTensor amax; + SimpleTensor columnwise_amax; SimpleTensor scale; SimpleTensor scale_inv; SimpleTensor columnwise_scale_inv; @@ -125,6 +132,7 @@ struct Tensor { : data(), columnwise_data(), amax(nullptr, {1}, DType::kFloat32), + columnwise_amax(nullptr, {1}, DType::kFloat32), scale(nullptr, {1}, DType::kFloat32), scale_inv(nullptr, {1}, DType::kFloat32), columnwise_scale_inv(nullptr, {1}, DType::kFloat32), @@ -135,6 +143,7 @@ struct Tensor { data.clear(); columnwise_data.clear(); amax.clear(); + columnwise_amax.clear(); scale.clear(); scale_inv.clear(); columnwise_scale_inv.clear(); @@ -181,21 +190,38 @@ struct Tensor { */ switch (scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: - if (!has_data() && has_columnwise_data()) { + case NVTE_NVFP4_1D_SCALING: { + // Choose data buffer based on whether it is initialized + // Note: Uninitialized buffers currently have shape=[]. + // However, this is logically incorrect. 0-D tensors have 1 + // entry, and uninitialized tensors should have shape=[0]. 
+ bool use_columnwise_shape = false; + if (data.dptr != nullptr) { + use_columnwise_shape = false; + } else if (columnwise_data.dptr != nullptr) { + use_columnwise_shape = true; + } else if (data.shape.size() != 0) { + use_columnwise_shape = false; + } else if (columnwise_data.shape.size() != 0) { + use_columnwise_shape = true; + } + + // Infer shape based on data + if (use_columnwise_shape) { + // Column-wise data is transposed std::vector ret; if (!columnwise_data.shape.empty()) { + ret.reserve(columnwise_data.shape.size()); for (size_t i = 1; i < columnwise_data.shape.size(); i++) { ret.push_back(columnwise_data.shape[i]); } ret.push_back(columnwise_data.shape.front()); } return ret; - } else { - return data.shape; } - break; + return data.shape; + } case NVTE_MXFP8_1D_SCALING: - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: if (!has_data() && has_columnwise_data()) { return columnwise_data.shape; } else { @@ -267,12 +293,18 @@ struct QuantizationConfig { NVTETensor noop_tensor = nullptr; Float8BlockScaleTensorFormat float8_block_scale_tensor_format = Float8BlockScaleTensorFormat::GEMM_READY; + NVTETensor rng_state = nullptr; + bool nvfp4_2d_quantization = false; + bool stochastic_rounding = false; static constexpr size_t attr_sizes[] = { - sizeof(bool), // force_pow_2_scales - sizeof(float), // amax_epsilon - sizeof(NVTETensor), // noop_tensor - sizeof(Float8BlockScaleTensorFormat) // float8_block_scale_tensor_format + sizeof(bool), // force_pow_2_scales + sizeof(float), // amax_epsilon + sizeof(NVTETensor), // noop_tensor + sizeof(Float8BlockScaleTensorFormat), // float8_block_scale_tensor_format + sizeof(NVTETensor), // rng_seed and offset + sizeof(bool), // nvfp4_2d_quantization + sizeof(bool) // stochastic_rounding }; }; @@ -322,6 +354,8 @@ using fp8e8m0 = __nv_fp8_e8m0; #endif // CUDA_VERSION >= 12080 #if FP4_TYPE_SUPPORTED using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; #endif //FP4_TYPE_SUPPORTED #else 
using bf16 = hip_bfloat16; @@ -370,6 +404,7 @@ struct TypeExtrema; template <> struct TypeExtrema { static constexpr float max = 6.0f; + static constexpr float max_inverse = 1.0 / max; }; #endif @@ -377,16 +412,20 @@ template <> struct TypeExtrema { #ifndef __HIP_PLATFORM_AMD__ static constexpr float max = 448.0f; + static constexpr float max_inverse = 1.0 / max; #elif defined(__HIP_DEVICE_COMPILE__) - static constexpr float maxNorm = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max = te_fp8_fnuz() ? 240.0f : 448.0f; + static constexpr float max_inverse = 1.0 / max; #else - static float maxNorm; + static float max; + static float max_inverse; #endif }; template <> struct TypeExtrema { static constexpr float max = 57344.0f; + static constexpr float max_inverse = 1.0 / max; }; template <> @@ -600,6 +639,18 @@ struct TypeInfo { NVTE_ERROR("Invalid type."); \ } +// Add a pack_size argument to select the packed type for FP4 +#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(dtype, pack_size, type, ...) \ + switch (dtype) { \ + using namespace transformer_engine; \ + case DType::kFloat4E2M1: { \ + using type = __nv_fp4x2_storage_t; \ + { __VA_ARGS__ } \ + } break; \ + default: \ + NVTE_ERROR("Invalid type."); \ + } + #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(dtype, type, ...) \ switch (dtype) { \ using namespace transformer_engine; \ @@ -764,13 +815,15 @@ void checkCuDriverContext(CUstream stream); CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. 
-void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, - const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, - const uint32_t shmemX, const uint32_t stride_elems, - const uint32_t offset_elems, const size_t type_num_bits); -#endif //#ifdef __HIP_PLATFORM_AMD__ +void create_2D_tensor_map( + CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, + const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, + const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits, + const CUtensorMapSwizzle swizzle = CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE); bool is_supported_by_CC_100(); +#endif //#ifdef __HIP_PLATFORM_AMD__ + std::vector> convert_tensor_array(NVTETensor **nvte_tensors, size_t outer_size, size_t inner_size); diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 795697635..611beb7b8 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -15,6 +15,74 @@ #include "fused_attn_fp8.h" #include "utils.h" +namespace { +// Helper function to create a tensor view with modified shape and optional pointer offset +transformer_engine::Tensor make_tensor_view(const transformer_engine::Tensor *source, + const std::vector &shape, + size_t offset_bytes = 0) { + transformer_engine::Tensor view = *source; + if (offset_bytes > 0) { + view.data.dptr = static_cast(static_cast(source->data.dptr) + offset_bytes); + } + view.data.shape = shape; + view.nvte_tensor = 0; // Mark as unmanaged/local tensor view + return view; +} + +// Helper function to calculate stride in bytes for packed QKV tensor unpacking +size_t calculate_qkv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + stride = 
(transformer_engine::typeToNumBits(dtype) * h * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for QKV packed tensor +std::vector calculate_qkv_unpacked_shape(const transformer_engine::Tensor *qkv_tensor, + size_t h, size_t d) { + std::vector unpacked_shape; + if (qkv_tensor->data.shape.size() == 4) { + // T3HD or TH3D (4D) -> THD (3D): remove dimension "3" at position 1 + unpacked_shape = {qkv_tensor->data.shape[0], h, d}; + } else { + // BS3HD/SB3HD or BSH3D/SBH3D (5D) -> BSHD/SBHD (4D): remove dimension "3" at position 2 + unpacked_shape = {qkv_tensor->data.shape[0], qkv_tensor->data.shape[1], h, d}; + } + return unpacked_shape; +} + +// Helper function to calculate stride for packed KV tensor unpacking +size_t calculate_kv_stride(NVTE_QKV_Layout_Group layout_group, transformer_engine::DType dtype, + size_t h_kv, size_t d) { + size_t stride = 0; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + stride = (transformer_engine::typeToNumBits(dtype) * h_kv * d) / 8; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + stride = (transformer_engine::typeToNumBits(dtype) * d) / 8; + } + return stride; +} + +// Helper function to determine unpacked shape for KV packed tensor +std::vector calculate_kv_unpacked_shape(const transformer_engine::Tensor *kv_tensor, + NVTE_QKV_Layout_Group layout_group, + NVTE_QKV_Format kv_format, size_t t_kv, size_t h_kv, + size_t d) { + std::vector unpacked_kv_shape; + if (kv_format == NVTE_QKV_Format::NVTE_THD) { + unpacked_kv_shape = {t_kv, h_kv, d}; + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD || + layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { + unpacked_kv_shape = {kv_tensor->data.shape[0], kv_tensor->data.shape[1], h_kv, d}; + } + return unpacked_kv_shape; +} +} // namespace + // map NVTE_QKV_Layout to NVTE_QKV_Layout_Group 
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group(NVTE_QKV_Layout qkv_layout) { switch (qkv_layout) { @@ -135,9 +203,10 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout) { // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; const int device_id = cuda::current_device(); @@ -165,7 +234,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || @@ -175,7 +244,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // TODO (cyang): add is_training to nvte_get_fused_attn_backend // sm90: fwd d<=256, bwd d=128 only // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && head_dim_qk <= 256 && head_dim_v <= 256) || + 
((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || @@ -183,9 +253,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset && + !requires_64bit_ragged_offset && (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && // 9.10.0: known bugs with SDPA FP8 - (cudnn_runtime_version != 91000)) { + (cudnn_runtime_version != 91000) && !return_max_logit) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -213,7 +283,8 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) { flag_m512 = true; } if ( @@ -363,7 +434,13 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // check 64-bit ragged offset support (supported_ragged_offset_size) && // 9.10.0/9.10.1: known bugs with SDPA F16 - (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001) && + // softmax type + // pre-9.13.1: vanilla + // 9.13.1+: vanilla, off-by-one, learnable + (cudnn_runtime_version >= 91301 || + (cudnn_runtime_version < 91301 && + softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) { flag_arb = true; } if 
(((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { @@ -398,6 +475,28 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( " Please upgrade your cuDNN version if possible." << std::endl; } + if ((cudnn_runtime_version == 91400) && (max_seqlen_kv > 1024) && (window_size_left != -1) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-causal) and " + "max_seqlen_kv (> 1024) does not support fused attention for cuDNN 9.14.0. " + " Please upgrade your cuDNN version if possible." + << std::endl; + } + if ((cudnn_runtime_version <= 91500) && is_training && + (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && + (max_seqlen_kv % 128 != 0) && cuda_graph && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) && + (attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK)) { + backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; + std::cout << "Warning: Given combination of attention mask (non-padding)," + " max_seqlen_kv (not divisible by 128), and qkv_format (BSHD/SBHD) for" + " backward fused attention with graph capture requires cuDNN 9.15.1+. " + "Please upgrade your cuDNN version if possible." 
+ << std::endl; + } } else { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; } @@ -405,12 +504,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( } // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); @@ -421,6 +524,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const Tensor *input_rng_state = convertNVTETensorCheck(rng_state); const Tensor *input_QKV = convertNVTETensorCheck(QKV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -447,35 +551,68 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, const NVTEDType QKV_type 
= static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_QKV, input_Bias, - output_O, Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, - wkspace, stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_fwd(b, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - fused_attn_arbitrary_seqlen_fwd_qkvpacked( - b, h, max_seqlen, d, t, is_training, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, window_size_left, window_size_right, input_QKV, input_Bias, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens_padded, input_rng_state, wkspace, - stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state, + wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_qkvpacked(b, h, max_seqlen, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_QKV, input_output_S, output_O, - Aux_CTX_Tensors, input_cu_seqlens, input_rng_state, wkspace, - stream, handle); + // Unpack QKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + fused_attn_fp8_fwd(b, h, h, max_seqlen, max_seqlen, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, &Q_view, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens, + input_cu_seqlens, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -484,15 +621,16 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } } // NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { +// DEPRECATED: This API is deprecated. 
+// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -505,6 +643,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con Tensor *input_output_dP = convertNVTETensorCheck(dP); Tensor *output_dQKV = convertNVTETensorCheck(dQKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_QKV->data.shape.size(); @@ -529,33 +668,67 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const NVTEDType QKV_type = static_cast(input_QKV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_qkvpacked( 
- b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, input_QKV, - input_dO, output_S, output_dQKV, output_dBias, input_cu_seqlens, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_max_512_bwd(b, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_dO, output_S, + &dQ_view, &dK_view, &dV_view, output_dBias, input_cu_seqlens, + input_cu_seqlens, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_qkvpacked( - b, h, max_seqlen, d, t, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, - window_size_left, window_size_right, deterministic, input_QKV, input_O, input_dO, - input_Bias, output_S, output_dQKV, output_dBias, input_cu_seqlens, input_cu_seqlens_padded, - input_rng_state, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = 
make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type, + attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view, + &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view, + &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, + input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.0 is required for BF16/FP16 fused attention " @@ -567,10 +740,26 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_qkvpacked(b, h, max_seqlen, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_QKV, input_O, input_dO, input_M, input_ZInv, - input_S, input_output_dP, output_dQKV, input_cu_seqlens, - input_rng_state, wkspace, stream, handle); + + // Unpack QKV and dQKV and call the non-packed function + const auto QKV_type = input_QKV->data.dtype; + size_t stride = calculate_qkv_stride(layout_group, QKV_type, h, d); + std::vector unpacked_shape = calculate_qkv_unpacked_shape(input_QKV, h, d); + + // Create tensor views for Q, K, V and dQ, dK, dV + Tensor Q_view = make_tensor_view(input_QKV, unpacked_shape); + Tensor K_view = make_tensor_view(input_QKV, unpacked_shape, stride); + Tensor V_view = make_tensor_view(input_QKV, unpacked_shape, 2 * stride); + + Tensor dQ_view = make_tensor_view(output_dQKV, unpacked_shape); + Tensor dK_view = make_tensor_view(output_dQKV, unpacked_shape, stride); + Tensor dV_view = make_tensor_view(output_dQKV, unpacked_shape, 2 * stride); 
+ + fused_attn_fp8_bwd(b, h, h, max_seqlen, max_seqlen, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, &Q_view, &K_view, &V_view, input_O, input_dO, + input_M, input_ZInv, input_S, input_output_dP, &dQ_view, &dK_view, &dV_view, + input_cu_seqlens, input_cu_seqlens, input_rng_state, wkspace, stream, + handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -579,15 +768,18 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } } // NVTE fused attention FWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead. void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type 
softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -600,6 +792,7 @@ void nvte_fused_attn_fwd_kvpacked( const Tensor *input_Q = convertNVTETensorCheck(Q); const Tensor *input_KV = convertNVTETensorCheck(KV); const Tensor *input_Bias = convertNVTETensorCheck(Bias); + const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -660,37 +853,68 @@ void nvte_fused_attn_fwd_kvpacked( const NVTEDType KV_type = static_cast(input_KV->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) - fused_attn_max_512_fwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + 
Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - fused_attn_arbitrary_seqlen_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_KV, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_fwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v, + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + 
window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { #if (CUDNN_VERSION >= 8900) - fused_attn_fp8_fwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, input_Q, input_KV, input_output_S, output_O, Aux_CTX_Tensors, - input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + // Unpack KV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + fused_attn_fp8_fwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, is_training, attn_scale, + dropout, qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, + input_output_S, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, + input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -699,14 +923,17 @@ void nvte_fused_attn_fwd_kvpacked( } } // NVTE fused attention BWD with packed KV +// DEPRECATED: This API is deprecated. +// Please use nvte_fused_attn_bwd with separate Q, K, V tensors instead. 
void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; @@ -723,6 +950,7 @@ void nvte_fused_attn_bwd_kvpacked( Tensor *output_dQ = convertNVTETensorCheck(dQ); Tensor *output_dKV = convertNVTETensorCheck(dKV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); size_t b = input_cu_seqlens_q->data.shape[0] - 1; @@ -754,36 +982,69 @@ void nvte_fused_attn_bwd_kvpacked( const NVTEDType Q_type = static_cast(input_Q->data.dtype); const NVTEDType KV_type = static_cast(input_KV->data.dtype); - NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - 
max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd_kvpacked( - b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, bias_type, - attn_mask_type, input_Q, input_KV, input_dO, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, input_Q->data.dtype, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_dO, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, input_cu_seqlens_q, + input_cu_seqlens_kv, wkspace, stream, handle); #else NVTE_ERROR("cuDNN 8.9.1 is required for BF16/FP16 fused attention with max_seqlen<=512. 
\n"); #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8903) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } - fused_attn_arbitrary_seqlen_bwd_kvpacked( - b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, t_q, t_kv, attn_scale, dropout, qkv_layout, - bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, input_Q, - input_KV, input_O, input_dO, input_Bias, output_S, output_dQ, output_dKV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, - input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + // Create 
tensor views for dK, dV + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_arbitrary_seqlen_bwd( + b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout, + bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, + input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, + output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q, + input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, + wkspace, stream, handle); #else const char *err_msg = "cuDNN 8.9.3 is required for BF16/FP16 fused attention " @@ -795,11 +1056,25 @@ void nvte_fused_attn_bwd_kvpacked( const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); const Tensor *input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - fused_attn_fp8_bwd_kvpacked(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_KV, input_O, - input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, - output_dKV, input_cu_seqlens_q, input_cu_seqlens_kv, - input_rng_state, wkspace, stream, handle); + + // Unpack KV and dKV and call the non-packed function + const auto Q_type = input_Q->data.dtype; + NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); + size_t stride = calculate_kv_stride(layout_group, Q_type, h_kv, d); + std::vector unpacked_kv_shape = + calculate_kv_unpacked_shape(input_KV, layout_group, kv_format, t_kv, h_kv, d); + + Tensor K_view = make_tensor_view(input_KV, unpacked_kv_shape); + Tensor V_view = make_tensor_view(input_KV, unpacked_kv_shape, stride); + + Tensor dK_view = make_tensor_view(output_dKV, unpacked_kv_shape); + Tensor dV_view = 
make_tensor_view(output_dKV, unpacked_kv_shape, stride); + + fused_attn_fp8_bwd(b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, attn_scale, dropout, + qkv_layout, bias_type, attn_mask_type, input_Q, &K_view, &V_view, input_O, + input_dO, input_M, input_ZInv, input_S, input_output_dP, output_dQ, &dK_view, + &dV_view, input_cu_seqlens_q, input_cu_seqlens_kv, input_rng_state, wkspace, + stream, handle); #else NVTE_ERROR("cuDNN 8.9.0 is required for FP8 fused attention. \n"); #endif @@ -809,14 +1084,16 @@ void nvte_fused_attn_bwd_kvpacked( } // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); @@ -832,6 +1109,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const Tensor *input_K = convertNVTETensorCheck(K); const Tensor *input_V = convertNVTETensorCheck(V); const Tensor *input_Bias = convertNVTETensorCheck(Bias); 
+ const Tensor *input_SoftmaxOffset = convertNVTETensorCheck(SoftmaxOffset); Tensor *input_output_S = convertNVTETensorCheck(S); Tensor *output_O = convertNVTETensorCheck(O); Tensor *wkspace = convertNVTETensor(workspace); @@ -886,8 +1164,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -902,11 +1181,12 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #if (CUDNN_VERSION >= 8900) fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, - page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, - input_Q, input_K, input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, - input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle); + page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, + return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type, + window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias, + input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv, + 
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k, + input_page_table_v, input_rng_state, wkspace, stream, handle); #else NVTE_ERROR( "cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n"); @@ -928,14 +1208,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -953,6 +1234,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso Tensor *output_dK = convertNVTETensorCheck(dK); Tensor *output_dV = convertNVTETensorCheck(dV); Tensor *output_dBias = convertNVTETensorCheck(dBias); + Tensor *output_dSoftmaxOffset = convertNVTETensorCheck(dSoftmaxOffset); Tensor *wkspace = convertNVTETensor(workspace); auto ndim = input_Q->data.shape.size(); @@ -978,8 +1260,9 @@ 
void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso const NVTEDType KV_type = static_cast(input_K->data.dtype); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { #if (CUDNN_VERSION >= 8901) @@ -993,19 +1276,22 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso #endif } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { #if (CUDNN_VERSION >= 8900) - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor *input_Bias, *input_rng_state; + size_t i = 0; + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + Tensor *input_Bias, *input_SoftmaxOffset; if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - } else { - input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + input_Bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + input_SoftmaxOffset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); } fused_attn_arbitrary_seqlen_bwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, window_size_left, window_size_right, deterministic, - input_Q, input_K, input_V, input_O, 
input_dO, input_Bias, output_S, output_dQ, output_dK, - output_dV, output_dBias, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, + qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, + deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias, + input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle); #else const char *err_msg = diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 4e6c3c858..14468b543 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -53,11 +53,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t num_pages_k, int64_t num_pages_v, int64_t page_size_k, int64_t page_size_v, int64_t max_pages_per_seq_k, int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, void *devPtrQ, void *devPtrK, void *devPtrV, void *devPtrBias, - void *devPtrSoftmaxStats, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, - void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, + bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK, + void *devPtrV, void *devPtrBias, void 
*devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2, + void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, + void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -75,6 +76,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (is_training && dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -98,35 +100,42 @@ void fused_attn_arbitrary_seqlen_fwd_impl( s_q = is_ragged_q ? max_t_q : s_q; s_kv = is_ragged_kv ? max_t_kv : s_kv; } - const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32; + const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; + bool generate_stats = !return_max_logit; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - num_pages_k, - num_pages_v, - page_size_k, - page_size_v, - max_pages_per_seq_k, - max_pages_per_seq_v, - bias_b, - bias_h, - scaling_factor, - is_training, - dropout_probability, - layout, - bias_type, - mask_type, - window_size_left, - window_size_right, - true, - tensorType, - tensorType}; + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + num_pages_k, + num_pages_v, + page_size_k, + page_size_v, + max_pages_per_seq_k, + max_pages_per_seq_v, + bias_b, + bias_h, + scaling_factor, + is_training, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + true, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + return_max_logit, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -136,8 +145,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // V std::shared_ptr, // attn_scale std::shared_ptr, // O - std::shared_ptr, // Stats + std::shared_ptr, // S1 + std::shared_ptr, // S2 std::shared_ptr, // bias + std::shared_ptr, // softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // page_table_k @@ -168,7 +179,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); - std::shared_ptr Q, K, V, attn_scale; + std::shared_ptr Q, K, V, attn_scale, softmax_offset; std::shared_ptr bias, seq_q, seq_kv; std::shared_ptr page_table_k, page_table_v; std::shared_ptr offset_q, offset_k, offset_v, offset_o, @@ -238,6 +249,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options = fe::graph::SDPA_attributes() .set_name("flash_attention") .set_is_inference(false) + .set_generate_stats(generate_stats) 
.set_causal_mask(is_causal) .set_causal_mask_bottom_right(is_bottom_right) .set_attn_scale(attn_scale); @@ -302,7 +314,45 @@ void fused_attn_arbitrary_seqlen_fwd_impl( sdpa_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } - auto [O, Stats] = mha_graph->sdpa(Q, K, V, sdpa_options); + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_options.set_sink_token(softmax_offset); + } + + std::shared_ptr Max, Sum_Exp; + if (is_ragged_q && cudnn_runtime_version >= 90600) { + offset_stats = + mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("offset_stats") + .set_dim({b + 1, 1, 1, 1}) + .set_stride({1, 1, 1, 1}) + .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); + } + if (return_max_logit) { + Max = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Max") + .set_dim({b, h, s_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("Sum_Exp") + .set_dim({b, h, s_q, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Max->set_stride({h * s_q, s_q, 1, 1}); + Sum_Exp->set_stride({h * s_q, s_q, 1, 1}); + } + sdpa_options.set_logit_max(Max); + sdpa_options.set_score_sum_exp(Sum_Exp); + } + + auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options)); std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d_v, o_stride.data(), layout, @@ -317,17 +367,13 @@ void fused_attn_arbitrary_seqlen_fwd_impl( O->set_ragged_offset(offset_o); } - Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); - if (is_ragged_q && cudnn_runtime_version >= 90600) { 
- offset_stats = - mha_graph->tensor(fe::graph::Tensor_attributes() - .set_name("offset_stats") - .set_dim({b + 1, 1, 1, 1}) - .set_stride({1, 1, 1, 1}) - .set_data_type(get_cudnn_fe_dtype(ragged_offset_type))); - Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); - } else { - Stats->set_stride({h * s_q, s_q, 1, 1}); + if (!return_max_logit) { + Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1}); + if (is_ragged_q && cudnn_runtime_version >= 90600) { + Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats); + } else { + Stats->set_stride({h * s_q, s_q, 1, 1}); + } } std::tuple, // Q @@ -336,8 +382,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( std::shared_ptr, // attn_scale std::shared_ptr> // O key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O); - auto Stats_tuple = std::make_tuple(Stats); + auto Stats_tuple = + generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp); auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr); + auto softmax_offset_tuple = + is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto page_table_tuple = is_paged_kv ? 
std::make_tuple(page_table_k, page_table_v) @@ -358,17 +407,18 @@ void fused_attn_arbitrary_seqlen_fwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = std::tuple_cat( - std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, padding_tuple, - page_table_tuple, offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = + std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, Stats_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, page_table_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, Q, K, V, attn_scale, O, Stats, bias, seq_q, seq_kv, page_table_k, page_table_v, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_fprop_cache, descriptor); + auto [mha_graph, Q, K, V, attn_scale, O, S1, S2, bias, softmax_offset, seq_q, seq_kv, + page_table_k, page_table_v, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, dropout_offset] = get_graph(sdpa_f16_fprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. 
@@ -399,9 +449,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl( // Build variant pack std::unordered_map, void *> variant_pack = { - {Q, devPtrQ}, {K, devPtrK}, - {V, devPtrV}, {attn_scale, &scaling_factor}, - {O, devPtrO}, {Stats, devPtrSoftmaxStats}}; + {Q, devPtrQ}, {K, devPtrK}, {V, devPtrV}, {attn_scale, &scaling_factor}, + {O, devPtrO}, {S1, devPtrS1}}; + + if (return_max_logit) { + variant_pack[S2] = devPtrS2; + } if (is_bias) { variant_pack[bias] = devPtrBias; @@ -473,6 +526,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl( variant_pack[dropout_seed] = devPtrDropoutSeed; variant_pack[dropout_offset] = devPtrDropoutOffset; } + + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -483,14 +541,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, void *devPtrQ, void *devPtrKTranspose, - void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, - void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, - void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, - cudnn_frontend::DataType_t tensorType, void *workspace, size_t *workspace_size, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ, + void 
*devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats, + void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV, + void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, + void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, + void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, + void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -506,6 +564,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( is_causal = true; is_bottom_right = false; } + bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX); bool is_dropout = (dropout_probability != 0.0f); NVTE_QKV_Format q_format = nvte_get_q_format(layout); NVTE_QKV_Format kv_format = nvte_get_kv_format(layout); @@ -537,32 +596,38 @@ void fused_attn_arbitrary_seqlen_bwd_impl( const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? 
DType::kInt64 : DType::kInt32; try { - FADescriptor_v1 descriptor{b, - h, - hg, - s_q, - s_kv, - d_qk, - d_v, - 0, - 0, - 0, - 0, - 0, - 0, - bias_b, - bias_h, - scaling_factor, - true, - dropout_probability, - layout, - bias_type, - mask_type, - window_size_left, - window_size_right, - deterministic, - tensorType, - tensorType}; + FADescriptor_v1 descriptor{ + b, + h, + hg, + s_q, + s_kv, + d_qk, + d_v, + 0, + 0, + 0, + 0, + 0, + 0, + bias_b, + bias_h, + scaling_factor, + true, + dropout_probability, + layout, + bias_type, + mask_type, + softmax_type, + window_size_left, + window_size_right, + deterministic, + tensorType, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + false, + }; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -579,6 +644,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr, // dV std::shared_ptr, // bias std::shared_ptr, // dBias + std::shared_ptr, // softmax_offset + std::shared_ptr, // d_softmax_offset std::shared_ptr, // seq_q std::shared_ptr, // seq_kv std::shared_ptr, // offset_q @@ -608,7 +675,8 @@ void fused_attn_arbitrary_seqlen_bwd_impl( .set_compute_data_type(fe::DataType_t::FLOAT); std::shared_ptr q, k, v, o, dO, stats, attn_scale; - std::shared_ptr bias, dBias, seq_q, seq_kv; + std::shared_ptr bias, dBias, softmax_offset, d_softmax_offset, + seq_q, seq_kv; std::shared_ptr offset_q, offset_k, offset_v, offset_o, offset_stats; std::shared_ptr dropout_seed, dropout_offset; @@ -771,6 +839,21 @@ void fused_attn_arbitrary_seqlen_bwd_impl( sdpa_backward_options.set_dropout(dropout_probability, dropout_seed, dropout_offset); } + if (is_softmax_offset) { + softmax_offset = mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_sink_token(softmax_offset); + d_softmax_offset = 
mha_graph->tensor(fe::graph::Tensor_attributes() + .set_name("d_softmax_offset") + .set_dim({1, h, 1, 1}) + .set_stride({h, 1, 1, 1}) + .set_data_type(fe::DataType_t::FLOAT)); + sdpa_backward_options.set_dsink_token(d_softmax_offset); + } + auto [dQ, dK, dV] = mha_graph->sdpa_backward(q, k, v, o, dO, stats, sdpa_backward_options); dQ->set_output(true).set_dim({b, h, s_q, d_qk}).set_stride(q_stride); @@ -796,6 +879,9 @@ void fused_attn_arbitrary_seqlen_bwd_impl( std::shared_ptr> // dV key_tensors_tuple = std::make_tuple(q, k, v, o, dO, stats, attn_scale, dQ, dK, dV); auto bias_tuple = is_bias ? std::make_tuple(bias, dBias) : std::make_tuple(nullptr, nullptr); + auto softmax_offset_tuple = is_softmax_offset + ? std::make_tuple(softmax_offset, d_softmax_offset) + : std::make_tuple(nullptr, nullptr); auto padding_tuple = is_padding ? std::make_tuple(seq_q, seq_kv) : std::make_tuple(nullptr, nullptr); auto offset_qo_tuple = @@ -814,17 +900,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( NVTE_CHECK_CUDNN_FE(mha_graph->check_support(handle)); NVTE_CHECK_CUDNN_FE(mha_graph->build_plans(handle)); - auto return_tuple = - std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, padding_tuple, - offset_qo_tuple, offset_kv_tuple, offset_s_tuple, dropout_tuple); + auto return_tuple = std::tuple_cat(std::make_tuple(mha_graph), key_tensors_tuple, bias_tuple, + softmax_offset_tuple, padding_tuple, offset_qo_tuple, + offset_kv_tuple, offset_s_tuple, dropout_tuple); cache.insert({descriptor, return_tuple}); return return_tuple; }; - auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, seq_q, seq_kv, - offset_q, offset_o, offset_k, offset_v, offset_stats, dropout_seed, dropout_offset] = - get_graph(sdpa_f16_bprop_cache, descriptor); + auto [mha_graph, q, k, v, o, dO, stats, attn_scale, dQ, dK, dV, bias, dBias, softmax_offset, + d_softmax_offset, seq_q, seq_kv, offset_q, offset_o, offset_k, offset_v, offset_stats, + dropout_seed, 
dropout_offset] = get_graph(sdpa_f16_bprop_cache, descriptor); // Exit to request upper level API to allocate memory if needed // n.b. Care should be taken to align each of the added worksapce tensors to their type. @@ -938,6 +1024,11 @@ void fused_attn_arbitrary_seqlen_bwd_impl( variant_pack[dropout_offset] = devPtrDropoutOffset; } + if (is_softmax_offset) { + variant_pack[softmax_offset] = devPtrSoftmaxOffset; + variant_pack[d_softmax_offset] = devPtrdSoftmaxOffset; + } + NVTE_CHECK_CUDNN_FE(mha_graph->execute(handle, variant_pack, workspace)); } catch (cudnn_frontend::cudnnException &e) { NVTE_ERROR(e.what()); @@ -946,465 +1037,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl( } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - 
void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens = 0; - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen, max_seqlen}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (qkv_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens, 
num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, 0, 0, 0, 0, 0, 0, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, nullptr, nullptr, devPtrSeqOffsets, - devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, - handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - 
workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_QKV->data.dtype; - void *devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - 
size_t max_tokens = 0; - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if (qkv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - max_tokens = get_max_tokens(num_tokens); - } - - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - void *devPtrSeqOffsets = cu_seqlens_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, head_dim, - max_batch_size, max_tokens, max_tokens, bias_b, bias_h, attn_scale, p_dropout, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, devPtrK, - devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, - devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlens, devPtrCuSeqlens, - devPtrSeqOffsets, devPtrSeqOffsets, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - 
size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - bias_b = input_Bias->data.shape[0]; - bias_h = input_Bias->data.shape[1]; - } - - void *devPtrO = output_O->data.dptr; - void *devPtrS = nullptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void 
*devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - if (Aux_CTX_Tensors->size == 0) { - const auto cudnn_runtime_version = cudnnGetVersion(); - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; - } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; - } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - } - output_S->data.dtype = DType::kFloat32; - Tensor 
*output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_fwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, - page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - 
return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - const auto QKV_type = input_Q->data.dtype; - void *devPtrQ = input_Q->data.dptr; - void *devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void *devPtrK = devPtrKV; - void *devPtrV = static_cast(static_cast(devPtrKV) + stride); - - void *devPtrO = input_O->data.dptr; - void *devPtrdO = input_dO->data.dptr; - void *devPtrBias = nullptr; - void *devPtrdBias = nullptr; - size_t bias_b = 0; - size_t bias_h = 0; - if ((bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) && (bias_type != NVTE_Bias_Type::NVTE_ALIBI)) { - devPtrBias = input_Bias->data.dptr; - devPtrdBias = output_dBias->data.dptr; - 
bias_b = output_dBias->data.shape[0]; - bias_h = output_dBias->data.shape[1]; - } - - size_t max_batch_size = 0; - size_t max_tokens_q = 0; - size_t max_tokens_kv = 0; - NVTE_QKV_Format q_format = nvte_get_q_format(qkv_layout); - NVTE_QKV_Format kv_format = nvte_get_kv_format(qkv_layout); - if (q_format == NVTE_QKV_Format::NVTE_THD || kv_format == NVTE_QKV_Format::NVTE_THD) { - max_batch_size = get_max_batch_size(batch); - } - if (q_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_q = get_max_tokens(num_tokens_q); - } - if (kv_format == NVTE_QKV_Format::NVTE_THD) { - max_tokens_kv = get_max_tokens(num_tokens_kv); - } - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdKV = output_dKV->data.dptr; - void *devPtrdK = devPtrdKV; - void *devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - - void *devPtrSoftmaxStats = nullptr; - devPtrSoftmaxStats = output_S->data.dptr; - - void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; - void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; - void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; - void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_arbitrary_seqlen_bwd_impl( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, head_dim, - max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, - devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == 
nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; @@ -1416,7 +1059,8 @@ void fused_attn_arbitrary_seqlen_fwd( void *devPtrK = input_K->data.dptr; void *devPtrV = input_V->data.dptr; void *devPtrO = output_O->data.dptr; - void *devPtrS = 
nullptr; + void *devPtrS1 = nullptr; + void *devPtrS2 = nullptr; void *devPtrBias = nullptr; size_t bias_b = 0; size_t bias_h = 0; @@ -1425,13 +1069,17 @@ void fused_attn_arbitrary_seqlen_fwd( bias_b = input_Bias->data.shape[0]; bias_h = input_Bias->data.shape[1]; } + void *devPtrSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; void *devPtrSeqOffsetsQ = cu_seqlens_q_padded->data.dptr; void *devPtrSeqOffsetsKV = cu_seqlens_kv_padded->data.dptr; - void *devPtrPageTableK = page_table_k->data.dptr; - void *devPtrPageTableV = page_table_v->data.dptr; + void *devPtrPageTableK = page_table_k ? page_table_k->data.dptr : nullptr; + void *devPtrPageTableV = page_table_v ? page_table_v->data.dptr : nullptr; size_t max_batch_size = 0; size_t max_tokens_q = 0; @@ -1446,29 +1094,28 @@ void fused_attn_arbitrary_seqlen_fwd( max_tokens_kv = get_max_tokens(num_tokens_kv); } + size_t i = 0; if (Aux_CTX_Tensors->size == 0) { const auto cudnn_runtime_version = cudnnGetVersion(); - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { - Aux_CTX_Tensors->size = 3; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Max->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { - output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; + output_Max->data.shape = {max_tokens_q, num_attn_heads, 1}; } else { - output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } - output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr 
= nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = nullptr; - output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; - output_bias->data.dtype = QKV_type; + output_Max->data.dtype = DType::kFloat32; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_Sum_Exp->data.dptr = nullptr; + if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { + output_Sum_Exp->data.shape = {max_tokens_q, num_attn_heads, 1}; + } else { + output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; + } + output_Sum_Exp->data.dtype = DType::kFloat32; } else { - Aux_CTX_Tensors->size = 2; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_S->data.dptr = nullptr; if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) { output_S->data.shape = {max_tokens_q, num_attn_heads, 1}; @@ -1476,23 +1123,48 @@ void fused_attn_arbitrary_seqlen_fwd( output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; } output_S->data.dtype = DType::kFloat32; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; } - } else if (Aux_CTX_Tensors->size == 2) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - output_rng_state->data.dptr = rng_state->data.dptr; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - Tensor *output_rng_state = 
convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); + + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_rng_state->data.dptr = nullptr; + output_rng_state->data.shape = {2}; + output_rng_state->data.dtype = DType::kInt64; + + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = nullptr; + output_bias->data.shape = {bias_b, bias_h, max_seqlen_q, max_seqlen_kv}; + output_bias->data.dtype = QKV_type; + } + + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = nullptr; + output_softmax_offset->data.shape = {1, num_attn_heads, 1, 1}; + output_softmax_offset->data.dtype = DType::kFloat32; + } + + Aux_CTX_Tensors->size = i; + } else if (Aux_CTX_Tensors->size >= 2) { + if (return_max_logit) { + Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_Max->data.dptr; + Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS2 = output_Sum_Exp->data.dptr; + } else { + Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + devPtrS1 = output_S->data.dptr; + } + Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; - Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_bias->data.dptr = devPtrBias; + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { + Tensor *output_bias = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_bias->data.dptr = devPtrBias; + } + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + Tensor *output_softmax_offset = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); + output_softmax_offset->data.dptr = devPtrSoftmaxOffset; + } } else { 
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); } @@ -1507,11 +1179,12 @@ void fused_attn_arbitrary_seqlen_fwd( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - devPtrQ, devPtrK, devPtrV, devPtrBias, devPtrS, devPtrO, devPtrDropoutSeed, - devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, - devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias, + devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset, + devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ, + devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, + stream, handle); if (workspace_size > 0) { if (workspace->data.dptr == nullptr) { @@ -1532,13 +1205,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor 
*cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto QKV_type = input_Q->data.dtype; void *devPtrQ = input_Q->data.dptr; @@ -1577,6 +1251,12 @@ void fused_attn_arbitrary_seqlen_bwd( void *devPtrdV = output_dV->data.dptr; void *devPtrSoftmaxStats = nullptr; devPtrSoftmaxStats = output_S->data.dptr; + void *devPtrSoftmaxOffset = nullptr; + void *devPtrdSoftmaxOffset = nullptr; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; + devPtrdSoftmaxOffset = output_dSoftmaxOffset->data.dptr; + } void *devPtrCuSeqlensQ = cu_seqlens_q->data.dptr; void *devPtrCuSeqlensKV = cu_seqlens_kv->data.dptr; @@ -1592,9 +1272,10 @@ void fused_attn_arbitrary_seqlen_bwd( fused_attn_arbitrary_seqlen_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, window_size_left, window_size_right, deterministic, devPtrQ, - devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, devPtrdQ, devPtrdK, 
devPtrdV, - devPtrdO, devPtrdBias, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, + qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, + deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias, + devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias, + devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h index e1a20274f..872b798bb 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.h @@ -18,58 +18,17 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -void fused_attn_arbitrary_seqlen_fwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, const Tensor *input_QKV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, size_t num_tokens, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, - const 
Tensor *input_Bias, Tensor *output_S, Tensor *output_dQKV, Tensor *output_dBias, - const Tensor *cu_seqlens, const Tensor *cu_seqlens_padded, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_fwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, - size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, - const Tensor *input_KV, const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, - const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_arbitrary_seqlen_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, size_t num_tokens_q, size_t num_tokens_kv, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - bool deterministic, const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, - const Tensor *input_dO, const Tensor *input_Bias, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, 
cudnnHandle_t handle); - void fused_attn_arbitrary_seqlen_fwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, size_t num_pages_k, size_t num_pages_v, size_t page_size_k, size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, int64_t window_size_left, int64_t window_size_right, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, + bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias, + const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); @@ -77,13 +36,14 @@ void fused_attn_arbitrary_seqlen_bwd( size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q, size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, const Tensor *input_Q, const Tensor *input_K, - const 
Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias, - Tensor *output_S, Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, - const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded, - const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q, + const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, + const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S, + Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias, + Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, + const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); #endif // CUDNN_VERSION >= 8900 } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu index 89528fa3c..1028df645 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu @@ -1215,150 +1215,6 @@ void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv } // namespace fused_attn using namespace transformer_engine::fused_attn; -void fused_attn_max_512_fwd_qkvpacked( - size_t batch, size_t num_head, size_t max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, const Tensor *input_Bias, Tensor 
*output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = input_QKV->data.dptr; - const auto stride = 2 * num_head * head_dim; - - void *devPtrQ = static_cast(devPtrQKV); - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrBias = static_cast(input_Bias->data.dptr); - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, max_seqlen, max_seqlen}; - output_S->data.dtype = input_QKV->data.dtype; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devPtrCuSeqlen = cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - const DType QKV_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devPtrCuSeqlen, devPtrCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(QKV_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - 
workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - NVTE_CHECK(bias_type == NVTE_Bias_Type::NVTE_NO_BIAS || - bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS, - "NVTE_PRE_SCALE_BIAS is not implemented in fused_attn_max_512."); - - // Q shape is [b, s, h, d] - void *devPtrQ = input_Q->data.dptr; - - // KV shape is [b, s, 2, h, d] - const auto stride = 2 * num_head * head_dim; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - 
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -1429,126 +1285,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, } } -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - - // QKV shape is [b, s, 3, h, d] - void *devPtrQKV = 
input_QKV->data.dptr; - - auto stride = 2 * num_head * head_dim; - void *devPtrQ = devPtrQKV; - void *devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void *devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQKV shape is [b, s, 3, h, d] - void *devPtrdQKV = output_dQKV->data.dptr; - void *devPtrdQ = devPtrdQKV; - void *devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void *devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrCuSeqlens = cu_seqlens->data.dptr; - - const auto qkv_type = input_QKV->data.dtype; - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl(batch, num_head, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, mask_type, bias_type, devPtrQ, devPtrK, - devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdS, - devPtrdBias, devPtrCuSeqlens, devPtrCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(qkv_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor 
*workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - // Q shape is [b, s, h, d] - // KV shape is [b, s, 2, h, d] - auto stride = 2 * num_head * head_dim; - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_KV->data.dptr; - void *devPtrV = static_cast(static_cast(devPtrK) + stride); - - void *devPtrdO = input_dO->data.dptr; - - // dQ shape is [b, s, h, d] - // dKV shape is [b, s, 2, h, d] - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dKV->data.dptr; - void *devPtrdV = static_cast(static_cast(devPtrdK) + stride); - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_KV->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, 
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h index 171fe846c..57b7afcf4 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h @@ -18,25 +18,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8901) -void fused_attn_max_512_fwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_size, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_Bias, - Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors, - const Tensor *cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_fwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, bool is_training, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, @@ -47,24 +28,6 @@ void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); -void fused_attn_max_512_bwd_qkvpacked(size_t batch, size_t num_head, size_t max_seqlen, - size_t head_dim, float attn_scale, float p_dropout, - 
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_QKV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQKV, - Tensor *output_dBias, const Tensor *cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_bwd_kvpacked(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, - const Tensor *input_dO, Tensor *output_S, Tensor *output_dQ, - Tensor *output_dKV, Tensor *output_dBias, - const Tensor *q_cu_seqlens, const Tensor *kv_cu_seqlens, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d7f098376..5d806290a 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1658,8 +1658,9 @@ void fused_attn_fp8_fwd_impl_v1( void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, void* workspace, size_t* workspace_size, + 
cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); bool is_alibi = (bias_type == NVTE_Bias_Type::NVTE_ALIBI); @@ -1672,6 +1673,13 @@ void fused_attn_fp8_fwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (o_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + o_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports O tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); try { FADescriptor_v1 descriptor{b, @@ -1695,11 +1703,15 @@ void fused_attn_fp8_fwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, true, - fwd_tensor_type, - fwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + cudnn_frontend::DataType_t::NOT_SET, + cudnn_frontend::DataType_t::NOT_SET, + false}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -1738,7 +1750,7 @@ void fused_attn_fp8_fwd_impl_v1( // otherwise, build the op_graph and the plan. 
Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -1786,7 +1798,13 @@ void fused_attn_fp8_fwd_impl_v1( descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); - scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + + if (is_delayed_scaling) { + scale_o = mha_graph->tensor_like(descale_q, "Scale_O"); + } + if (is_current_scaling) { + scale_o = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_attributes sdpa_options; sdpa_options = fe::graph::SDPA_fp8_attributes() @@ -1838,11 +1856,12 @@ void fused_attn_fp8_fwd_impl_v1( std::vector o_stride(4); generateMatrixStrides(b, h, s_q, s_kv, d, o_stride.data(), layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride); + O->set_output(true).set_dim({b, h, s_q, d}).set_stride(o_stride).set_data_type(o_tensor_type); amax_o->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) .set_data_type(fe::DataType_t::FLOAT); + amax_s->set_output(true) .set_dim({1, 1, 1, 1}) .set_stride({1, 1, 1, 1}) @@ -1915,13 +1934,16 @@ void fused_attn_fp8_fwd_impl_v1( {descale_v, devPtrDescaleV}, {descale_s, devPtrDescaleS}, {scale_s, devPtrScaleS}, - {scale_o, devPtrScaleO}, {attn_scale, &scaling_factor}, {O, devPtrO}, {amax_s, devPtrAmaxS}, {amax_o, devPtrAmaxO}, {Stats, devPtrM}}; + if (is_delayed_scaling) { + variant_pack[scale_o] = devPtrScaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; } */ @@ -1962,8 +1984,9 @@ void fused_attn_fp8_bwd_impl_v1( void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* 
devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnn_frontend::DataType_t fwd_tensor_type, - cudnn_frontend::DataType_t bwd_tensor_type, void* workspace, size_t* workspace_size, + void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, + cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, + cudnn_frontend::DataType_t dqkv_tensor_type, void* workspace, size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -1977,6 +2000,15 @@ void fused_attn_fp8_bwd_impl_v1( auto bias_h = h; NVTE_CHECK(~is_bias, "FP8 fused attention does not support pre/post_scale_bias yet!"); NVTE_CHECK(~is_alibi, "FP8 fused attention does not support ALiBi yet!"); + bool is_current_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::HALF || + dqkv_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); + bool is_delayed_scaling = (dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E4M3 || + dqkv_tensor_type == cudnn_frontend::DataType_t::FP8_E5M2); + NVTE_CHECK(is_current_scaling || is_delayed_scaling, + "FP8 fused attention only supports dQKV tensor in kFloat16, kBFloat16, kFloat8E4M3 or " + "kFloat8E5M2!"); + bool is_O_in_F16 = (o_tensor_type == cudnn_frontend::DataType_t::HALF || + o_tensor_type == cudnn_frontend::DataType_t::BFLOAT16); try { FADescriptor_v1 descriptor{b, @@ -2000,11 +2032,15 @@ void fused_attn_fp8_bwd_impl_v1( layout, bias_type, mask_type, + NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX, 0, 0, false, - fwd_tensor_type, - bwd_tensor_type}; + qkv_tensor_type, + o_tensor_type, + do_tensor_type, + dqkv_tensor_type, + false}; namespace fe = cudnn_frontend; using graph_and_tensors = @@ -2057,7 +2093,7 @@ void fused_attn_fp8_bwd_impl_v1( // otherwise, build the op_graph and the plan. 
Then update cache auto mha_graph = std::make_shared(); - mha_graph->set_io_data_type(fwd_tensor_type) + mha_graph->set_io_data_type(qkv_tensor_type) .set_intermediate_data_type(fe::DataType_t::FLOAT) .set_compute_data_type(fe::DataType_t::FLOAT); @@ -2097,7 +2133,8 @@ void fused_attn_fp8_bwd_impl_v1( o = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("O") .set_dim({b, h, s_q, d}) - .set_stride(o_stride)); + .set_stride(o_stride) + .set_data_type(o_tensor_type)); dO = mha_graph->tensor(fe::graph::Tensor_attributes() .set_name("dO") .set_dim({b, h, s_q, d}) @@ -2123,14 +2160,26 @@ void fused_attn_fp8_bwd_impl_v1( descale_k = mha_graph->tensor_like(descale_q, "Descale_q"); descale_v = mha_graph->tensor_like(descale_q, "Descale_V"); descale_s = mha_graph->tensor_like(descale_q, "Descale_S"); - descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); descale_dP = mha_graph->tensor_like(descale_q, "Descale_dP"); + if (is_O_in_F16) { + descale_o = mha_graph->tensor(1.0f); + } else { + descale_o = mha_graph->tensor_like(descale_q, "Descale_O"); + } descale_dO = mha_graph->tensor_like(descale_q, "Descale_dO"); scale_s = mha_graph->tensor_like(descale_q, "Scale_S"); scale_dP = mha_graph->tensor_like(descale_q, "Scale_dP"); - scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); - scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); - scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + + if (is_delayed_scaling) { + scale_dQ = mha_graph->tensor_like(descale_q, "Scale_dQ"); + scale_dK = mha_graph->tensor_like(descale_q, "Scale_dK"); + scale_dV = mha_graph->tensor_like(descale_q, "Scale_dV"); + } + if (is_current_scaling) { + scale_dQ = mha_graph->tensor(1.0f); + scale_dK = mha_graph->tensor(1.0f); + scale_dV = mha_graph->tensor(1.0f); + } fe::graph::SDPA_fp8_backward_attributes sdpa_backward_options; sdpa_backward_options = fe::graph::SDPA_fp8_backward_attributes() @@ -2212,10 +2261,10 @@ void fused_attn_fp8_bwd_impl_v1( .set_stride({1, 1, 1, 1}) 
.set_data_type(fe::DataType_t::FLOAT); - dO->set_data_type(bwd_tensor_type); - dQ->set_data_type(bwd_tensor_type); - dK->set_data_type(bwd_tensor_type); - dV->set_data_type(bwd_tensor_type); + dO->set_data_type(do_tensor_type); + dQ->set_data_type(dqkv_tensor_type); + dK->set_data_type(dqkv_tensor_type); + dV->set_data_type(dqkv_tensor_type); std::tuple, // q std::shared_ptr, // k @@ -2296,14 +2345,10 @@ void fused_attn_fp8_bwd_impl_v1( {descale_q, devPtrDescaleQ}, {descale_k, devPtrDescaleK}, {descale_v, devPtrDescaleV}, - {descale_o, devPtrDescaleO}, {descale_dO, devPtrDescaledO}, {descale_s, devPtrDescaleS}, {descale_dP, devPtrDescaledP}, {scale_s, devPtrScaleS}, - {scale_dQ, devPtrScaledQ}, - {scale_dK, devPtrScaledK}, - {scale_dV, devPtrScaledV}, {scale_dP, devPtrScaledP}, {dQ, devPtrdQ}, {dK, devPtrdK}, @@ -2314,6 +2359,15 @@ void fused_attn_fp8_bwd_impl_v1( {amax_dP, devPtrAmaxdP}, }; + if (is_delayed_scaling) { + variant_pack[scale_dQ] = devPtrScaledQ; + variant_pack[scale_dK] = devPtrScaledK; + variant_pack[scale_dV] = devPtrScaledV; + } + if (!is_O_in_F16) { + variant_pack[descale_o] = devPtrDescaleO; + } + /* if (is_bias) { variant_pack[bias] = devPtrBias; if ((bias_b == 1) && (bias_h == h)) { @@ -2353,410 +2407,6 @@ void fused_attn_fp8_bwd_impl_v1( } // namespace fused_attn #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - 
NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = static_cast(devPtrQKV); - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - 
NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, - devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, - devPtrAmaxS, devPtrcuSeqlens, devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_QKV, const Tensor* input_O, const Tensor* input_dO, const Tensor* input_M, - const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQKV, const Tensor* cu_seqlens, const Tensor* rng_state, Tensor* workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_QKV->data.dtype; - const DType dQKV_type = output_dQKV->data.dtype; - void* devPtrQKV = input_QKV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - stride = (typeToNumBits(QKV_type) * num_attn_heads * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_H3D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrQ = devPtrQKV; - void* devPtrK = static_cast(static_cast(devPtrQKV) + stride); - void* devPtrV = static_cast(static_cast(devPtrQKV) + 2 * stride); - void* devPtrDescaleQ = input_QKV->scale_inv.dptr; - void* devPtrDescaleK = input_QKV->scale_inv.dptr; - void* devPtrDescaleV = input_QKV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = 
input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQKV = output_dQKV->data.dptr; - void* devPtrdQ = devPtrdQKV; - void* devPtrdK = static_cast(static_cast(devPtrdQKV) + stride); - void* devPtrdV = static_cast(static_cast(devPtrdQKV) + 2 * stride); - void* devPtrAmaxdQ = output_dQKV->amax.dptr; - void* devPtrAmaxdK = output_dQKV->amax.dptr; - void* devPtrAmaxdV = output_dQKV->amax.dptr; - void* devPtrScaledQ = output_dQKV->scale.dptr; - void* devPtrScaledK = output_dQKV->scale.dptr; - void* devPtrScaledV = output_dQKV->scale.dptr; - - void* devPtrcuSeqlens = - reinterpret_cast(reinterpret_cast(cu_seqlens->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, devPtrcuSeqlens, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == 
NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen, max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, devPtrdK, - devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, - devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, - devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlens, - devPtrcuSeqlens, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_dtype(QKV_type), - workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor* input_Q, - const Tensor* input_KV, Tensor* input_output_S, Tensor* output_O, - NVTETensorPack* Aux_CTX_Tensors, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, - Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == 
NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = output_O->data.dptr; - void* devPtrAmaxO = output_O->amax.dptr; - void* devPtrScaleO = output_O->scale.dptr; - - void* devPtrM = nullptr; - void* devPtrZInv = nullptr; - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 3; - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - output_M->data.dptr = nullptr; - output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_M->data.dtype = DType::kFloat32; - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - output_rng_state->data.dptr = nullptr; - output_rng_state->data.shape = {2}; - output_rng_state->data.dtype = DType::kInt64; - } else if (Aux_CTX_Tensors->size == 3) { - Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[1]); - Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[2]); - devPtrM = output_M->data.dptr; - devPtrZInv = output_ZInv->data.dptr; - output_rng_state->data.dptr = rng_state->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void* devPtrAmaxS = input_output_S->amax.dptr; - void* devPtrScaleS = input_output_S->scale.dptr; - void* devPtrDescaleS = 
input_output_S->scale_inv.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, - devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, - devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. 
\n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor* input_Q, const Tensor* input_KV, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_M, const Tensor* input_ZInv, const Tensor* input_S, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dKV, const Tensor* cu_seqlens_q, - const Tensor* cu_seqlens_kv, const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - const DType QKV_type = input_Q->data.dtype; - const DType dQKV_type = output_dQ->data.dtype; - void* devPtrQ = input_Q->data.dptr; - void* devPtrKV = input_KV->data.dptr; - NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); - size_t stride = 0; - if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - stride = (typeToNumBits(QKV_type) * num_gqa_groups * head_dim) / 8; - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_H2D) { - stride = (typeToNumBits(QKV_type) * head_dim) / 8; - } - void* devPtrK = devPtrKV; - void* devPtrV = static_cast(static_cast(devPtrKV) + stride); - void* devPtrDescaleQ = input_Q->scale_inv.dptr; - void* devPtrDescaleK = input_KV->scale_inv.dptr; - void* devPtrDescaleV = input_KV->scale_inv.dptr; - - void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; - void* devPtrdO = input_dO->data.dptr; - void* devPtrDescaledO = 
input_dO->scale_inv.dptr; - - void* devPtrM = input_M->data.dptr; - void* devPtrZInv = input_ZInv->data.dptr; - - void* devPtrScaleS = input_S->scale.dptr; - void* devPtrDescaleS = input_S->scale_inv.dptr; - void* devPtrAmaxdP = input_output_dP->amax.dptr; - void* devPtrScaledP = input_output_dP->scale.dptr; - void* devPtrDescaledP = input_output_dP->scale_inv.dptr; - - void* devPtrdQ = output_dQ->data.dptr; - void* devPtrdKV = output_dKV->data.dptr; - void* devPtrdK = devPtrdKV; - void* devPtrdV = static_cast(static_cast(devPtrdKV) + stride); - void* devPtrAmaxdQ = output_dQ->amax.dptr; - void* devPtrAmaxdK = output_dKV->amax.dptr; - void* devPtrAmaxdV = output_dKV->amax.dptr; - void* devPtrScaledQ = output_dQ->scale.dptr; - void* devPtrScaledK = output_dKV->scale.dptr; - void* devPtrScaledV = output_dKV->scale.dptr; - - void* devPtrcuSeqlensQ = - reinterpret_cast(reinterpret_cast(cu_seqlens_q->data.dptr)); - void* devPtrcuSeqlensKV = - reinterpret_cast(reinterpret_cast(cu_seqlens_kv->data.dptr)); - void* devPtrDropoutSeed = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr)); - void* devPtrDropoutOffset = - reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); - if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( - batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, - p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, - devPtrO, devPtrdO, devPtrdQ, devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, 
get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); - } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); - } - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } -} // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, @@ -2820,6 +2470,7 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType O_type = output_O->data.dtype; size_t workspace_size = 0; NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); @@ -2829,8 +2480,8 @@ void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou attn_scale, p_dropout, qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, 
devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, - devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, - &workspace_size, stream, handle); + devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), + get_cudnn_fe_dtype(O_type), workspace->data.dptr, &workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, is_training, attn_scale, @@ -2876,7 +2527,11 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou void* devPtrDescaleV = input_Q->scale_inv.dptr; void* devPtrO = input_O->data.dptr; - void* devPtrDescaleO = input_O->scale_inv.dptr; + const DType O_type = input_O->data.dtype; + void* devPtrDescaleO = nullptr; + if (O_type == DType::kFloat8E4M3 || O_type == DType::kFloat8E5M2) { + devPtrDescaleO = input_O->scale_inv.dptr; + } void* devPtrdO = input_dO->data.dptr; void* devPtrDescaledO = input_dO->scale_inv.dptr; @@ -2909,6 +2564,7 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou reinterpret_cast(reinterpret_cast(rng_state->data.dptr) + 1); const DType QKV_type = input_Q->data.dtype; + const DType dO_type = input_dO->data.dtype; const DType dQKV_type = output_dQ->data.dtype; size_t workspace_size = 0; @@ -2922,7 +2578,8 @@ void fused_attn_fp8_bwd(size_t batch, size_t num_attn_heads, size_t num_gqa_grou devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), - get_cudnn_fe_dtype(dQKV_type), workspace->data.dptr, &workspace_size, stream, handle); + get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), + workspace->data.dptr, 
&workspace_size, stream, handle); } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim, attn_scale, p_dropout, diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index 3daf45d16..c2efa2582 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -13,47 +13,6 @@ namespace transformer_engine { #if (CUDNN_VERSION >= 8900) -// fused attention FWD FP8 with packed QKV -void fused_attn_fp8_fwd_qkvpacked(size_t batch, size_t num_attn_heads, size_t max_seqlen, - size_t head_dim, bool is_training, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - -// fused attention BWD FP8 with packed QKV -void fused_attn_fp8_bwd_qkvpacked( - size_t batch, size_t num_attn_heads, size_t max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_QKV, const Tensor *input_O, const Tensor *input_dO, const Tensor *input_M, - const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQKV, const Tensor *cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -// fused attention FWD FP8 with packed KV -void fused_attn_fp8_fwd_kvpacked(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, - size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, - bool is_training, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type 
mask_type, const Tensor *input_Q, - const Tensor *input_KV, Tensor *input_output_S, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, - Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); - -// fused attention BWD FP8 with packed KV -void fused_attn_fp8_bwd_kvpacked( - size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, - size_t max_seqlen_kv, size_t head_dim, float attn_scale, float p_dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_KV, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_M, const Tensor *input_ZInv, const Tensor *input_S, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dKV, const Tensor *cu_seqlens_q, - const Tensor *cu_seqlens_kv, const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); - // fused attention FWD FP8 with separate Q, K, V void fused_attn_fp8_fwd(size_t batch, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index 678b63691..72047a73f 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -107,23 +107,29 @@ struct FADescriptor_v1 { NVTE_QKV_Layout layout; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; + NVTE_Softmax_Type softmax_type; std::int64_t window_size_left; std::int64_t window_size_right; bool deterministic; - cudnn_frontend::DataType_t fwd_tensor_type; - cudnn_frontend::DataType_t bwd_tensor_type; + cudnn_frontend::DataType_t qkv_tensor_type; + cudnn_frontend::DataType_t o_tensor_type; + cudnn_frontend::DataType_t do_tensor_type; + cudnn_frontend::DataType_t dqkv_tensor_type; + bool generate_max_sum_exp; bool 
operator<(const FADescriptor_v1 &rhs) const { return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, - attnScale, isTraining, dropoutProbability, layout, mask_type, window_size_left, - window_size_right, deterministic, bias_type, fwd_tensor_type, bwd_tensor_type) < + attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type, + o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) < std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k, rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k, rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining, - rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.window_size_left, - rhs.window_size_right, rhs.deterministic, rhs.bias_type, rhs.fwd_tensor_type, - rhs.bwd_tensor_type); + rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type, + rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type, + rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type, + rhs.dqkv_tensor_type, rhs.generate_max_sum_exp); } }; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp index bb5e22887..97aecf4de 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ @@ -276,11 +276,13 @@ void log_fused_attn_config( // select a backend for fused attention NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { using namespace transformer_engine; + NVTE_CHECK(!(return_max_logit || cuda_graph), "ROCm does not support return_max_logit and cuda_graph for fused_attn yet."); // by default, fused attn is enabled bool nvte_fused_attn = true; if (const char* env_p = std::getenv("NVTE_FUSED_ATTN") ) { @@ -311,6 +313,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, @@ -325,6 +328,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, @@ -339,12 +343,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // NVTE fused attention FWD with packed QKV -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool 
is_training, +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool cuda_graph, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked); @@ -384,9 +390,10 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, - max_seqlen, max_seqlen, d, d, window_size_left, window_size_right); - + is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit, + cuda_graph); + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_qkvpacked( b, h, max_seqlen, d, @@ -416,15 +423,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, } // NVTE fused attention BWD with packed QKV -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, 
float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream) { +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked); using namespace transformer_engine; @@ -468,8 +474,8 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, dropout, h, h, max_seqlen, - max_seqlen, d, d, window_size_left, window_size_right); + true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h, + max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)){ @@ -505,14 +511,17 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con } // NVTE fused attention FWD with packed KV -void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - 
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream) { +void nvte_fused_attn_fwd_kvpacked( + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { + NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -556,8 +565,9 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right); + is_training, 
Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd_kvpacked( @@ -596,11 +606,12 @@ void nvte_fused_attn_fwd_kvpacked(const NVTETensor Q, const NVTETensor KV, const void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked); using namespace transformer_engine; @@ -649,9 +660,10 @@ void nvte_fused_attn_bwd_kvpacked( // fix the incompatible window size from upstream frameworks pytorch/jax std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); - NVTE_Fused_Attn_Backend fused_attention_backend = 
nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d, d, window_size_left, window_size_right); + NVTE_Fused_Attn_Backend fused_attention_backend = + nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, + softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, + d, window_size_left, window_size_right, false, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { @@ -694,14 +706,16 @@ void nvte_fused_attn_bwd_kvpacked( // NVTE fused attention FWD with separate Q, K and V void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S, + NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + bool return_max_logit, bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_fwd); @@ -740,8 +754,9 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = 
check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, - max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, + h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, + return_max_logit, cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { fused_attn_ck_fwd( @@ -780,14 +795,15 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream) { + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_flash_attn_bwd); using namespace transformer_engine; const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q); @@ -830,8 +846,9 @@ void 
nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso std::tie(window_size_left, window_size_right) = check_set_window_size(attn_mask_type, std::make_pair(window_size_left, window_size_right)); NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( - true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, dropout, h_q, h_kv, max_seqlen_q, - max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right); + true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q, + h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, + cuda_graph); if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_CK) { if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI)) { diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp index a8a151b40..9109ddb15 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ @@ -44,6 +44,7 @@ bool is_aotriton_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -68,7 +69,10 @@ bool is_aotriton_backend_supported( if(!(is_no_mask_window_size || is_causal_mask_window_size)){ return false; } - + + if(softmax_type!=NVTE_VANILLA_SOFTMAX){ + return false; + } //aotriton fused attn does not support gqa mode now if(num_attn_heads!=num_gqa_groups){ return false; diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h index b016acc67..fd4dffd73 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_aotriton.h @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ @@ -23,6 +23,7 @@ bool is_aotriton_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, diff --git a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp index 7ca6fc95f..02bc9ce94 100644 --- a/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp +++ b/transformer_engine/common/fused_attn_rocm/fused_attn_ck.cpp @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -26,6 +26,7 @@ bool is_ck_backend_supported( NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, + NVTE_Softmax_Type softmax_type, float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, @@ -80,6 +81,14 @@ bool is_ck_backend_supported( return false; } + // filter based on softmax type + if(softmax_type!=NVTE_VANILLA_SOFTMAX){ + if(nvte_log_ck_config){ + std::cout<<"AITER/CK fused attn does not support learnable sink yet"< 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. 
if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; - s_id_for_freqs = s_id + begin_offset; } fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -175,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq template __global__ void fused_rope_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, - const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, - const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, - const int o_stride_d) { + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, + scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s, + const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block, offset_block_dst; int cur_seqlens; @@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel( cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 
0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - s_id_for_freqs = s_id; } fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c template void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const float *freqs, const int *start_positions, + scalar_t *input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); @@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, - stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, - o_stride_d); + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, + s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, + o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *input_grads, const NVTE_QKV_Format qkv_format, - const bool interleaved, const int cp_size, const int cp_rank, const int s, - const int b, const int h, const int d, const int d2, - const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, cudaStream_t stream) { + const Tensor &start_positions, Tensor *input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);); @@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const 
NVTETensor input, const NVTETensor cu_seqlens } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens), - *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads), - qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream); + *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size, + cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, diff --git a/transformer_engine/common/gemm/config.cpp b/transformer_engine/common/gemm/config.cpp new file mode 100644 index 000000000..cf211beaf --- /dev/null +++ b/transformer_engine/common/gemm/config.cpp @@ -0,0 +1,116 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "./config.h" + +#include +#include + +#include + +#include "../util/logging.h" + +NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; } + +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written) { + // Write attribute size + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + *size_written = attr_size; + + // Return immediately if buffer is not provided + if (buf == nullptr) { + return; + } + + // Check buffer size + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + + // Write to buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + const auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(buf, &config_.bias_tensor, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(buf, &config_.dbias_tensor, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(buf, &config_.with_gelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(buf, &config_.with_dgelu_epilogue, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(buf, &config_.use_split_accumulator, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(buf, &config_.sm_count, attr_size); + break; + default: + 
NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes) { + // Check attribute and buffer + NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ", + static_cast(attr), ")"); + const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr]; + NVTE_CHECK(size_in_bytes >= attr_size, + "Buffer is too small for matmul config attribute " + "(attribute ", + static_cast(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes, + " bytes)"); + NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)"); + + // Read from buffer + NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)"); + auto &config_ = *reinterpret_cast(config); + switch (attr) { + case kNVTEMatmulConfigBiasTensor: + std::memcpy(&config_.bias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigDBiasTensor: + std::memcpy(&config_.dbias_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigWithGELUEpilogue: + std::memcpy(&config_.with_gelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigWithDGELUEpilogue: + std::memcpy(&config_.with_dgelu_epilogue, buf, attr_size); + break; + case kNVTEMatmulConfigEpilogueAuxTensor: + std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size); + break; + case kNVTEMatmulConfigUseSplitAccumulator: + std::memcpy(&config_.use_split_accumulator, buf, attr_size); + break; + case kNVTEMatmulConfigSMCount: + std::memcpy(&config_.sm_count, buf, attr_size); + break; + default: + NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast(attr), ")"); + } +} + +void nvte_destroy_matmul_config(NVTEMatmulConfig config) { + if (config != nullptr) { + delete reinterpret_cast(config); + } +} diff --git a/transformer_engine/common/gemm/config.h b/transformer_engine/common/gemm/config.h new file mode 100644 index 
000000000..54ccf06a5 --- /dev/null +++ b/transformer_engine/common/gemm/config.h @@ -0,0 +1,36 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_ +#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_ + +#include + +namespace transformer_engine { + +struct MatmulConfig { + NVTETensor bias_tensor = nullptr; + NVTETensor dbias_tensor = nullptr; + bool with_gelu_epilogue = false; + bool with_dgelu_epilogue = false; + NVTETensor epilogue_aux_tensor = nullptr; + bool use_split_accumulator = false; + int sm_count = 0; + + static constexpr size_t attr_sizes[] = { + sizeof(NVTETensor), // bias_tensor + sizeof(NVTETensor), // dbias_tensor + sizeof(bool), // with_gelu_epilogue + sizeof(bool), // with_dgelu_epilogue + sizeof(NVTETensor), // epilogue_aux_tensor + sizeof(bool), // use_split_accumulator + sizeof(int) // sm_count + }; +}; + +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_ diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9c2ca9b4c..0127a9edf 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -14,23 +14,58 @@ #include #include +#include #include +#include #include #include +#include #include "../common.h" +#include "../util/cuda_runtime.h" #include "../util/handle_manager.h" #include "../util/logging.h" #include "../util/multi_stream.h" -#include "common/util/cuda_runtime.h" +#include "./config.h" #ifndef __HIP_PLATFORM_AMD__ -#include "cutlass_grouped_gemm.cuh" +#include "./cutlass_grouped_gemm.cuh" #endif #ifndef __HIP_PLATFORM_AMD__ namespace { +/* Use CUDA const memory to store scalar 1 and 0 for cublas usage +*/ +__device__ __constant__ float one_device; +__device__ __constant__ float zero_device; + +inline float *GetScalarOne() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float one = 1.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(one_device, &one, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), one_device)); + return dev_ptr; +} + +inline float *GetScalarZero() { + static std::once_flag init_flag; + std::call_once(init_flag, []() { + float zero = 0.0f; + NVTE_CHECK_CUDA(cudaMemcpyToSymbol(zero_device, &zero, sizeof(float))); + }); + // return address by cudaGetSymbolAddress + float *dev_ptr; + NVTE_CHECK_CUDA(cudaGetSymbolAddress(reinterpret_cast(&dev_ptr), zero_device)); + return dev_ptr; +} + +__global__ __launch_bounds__(1) void set_float_kernel(float *ptr, float val) { *ptr = val; } + uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; @@ -90,6 +125,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla bool is_A_transposed = transA == CUBLAS_OP_T; bool is_B_transposed = transB == CUBLAS_OP_T; + // Set conditions for MXFP8 and NVFP4 gemm execution. 
+ const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode); + const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode); + // Configure A matrix if (is_tensor_scaling(A.scaling_mode)) { // Unscaled or FP8 tensor scaling @@ -109,11 +148,43 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } + } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed + // data with the mirrored transpose-flag if we don't have row-wise data. + NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), + "Input A is missing column-wise usage"); + ret.A = A.columnwise_data.dptr; + ret.transA = is_A_transposed ? CUBLAS_OP_N : CUBLAS_OP_T; + ret.Atype = A.columnwise_data.dtype; + ret.A_scale_inv = A.columnwise_scale_inv.dptr; + ret.lda = is_A_transposed ? m : k; } - } else if (is_mxfp_scaling(A.scaling_mode)) { - // MXFP8 + + if (is_fp8_dtype(ret.Atype)) { + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK(ret.lda % 16 == 0, + "Leading dimension requirement on A for FP8 GEMM. Caller must pad."); + } + } else if (nvfp4) { + // NVFP4 GEMM. Either the pure NVFP4 recipe or the FWD pass of the Hybrid NVFP4/MXFP8 recipe. + + if (is_A_transposed) { + NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); + } else { + NVTE_CHECK(is_nvfp4_scaling(A.scaling_mode), + "Input A has unsupported combination of recipe and layout"); + NVTE_CHECK(A.has_columnwise_data(), "Input A is missing column-wise usage"); + } + ret.A = is_A_transposed ? A.data.dptr : A.columnwise_data.dptr; + ret.transA = CUBLAS_OP_T; // NVFP4 gemm is only supported in TN layout. + ret.Atype = is_A_transposed ? A.data.dtype : A.columnwise_data.dtype; + ret.A_scale_inv = is_A_transposed ? 
A.scale_inv.dptr : A.columnwise_scale_inv.dptr; + ret.lda = k; + } else if (mxfp8) { + // MXFP8 GEMM. Either for pure MXFP8 recipe or backward of Hybrid NVFP4 recipe. // Note: Row-wise and column-wise data are scaled along different // dimensions (with matrix interpreted in row-major order). + if (is_A_transposed) { NVTE_CHECK(A.has_data(), "Input A is missing row-wise usage"); } else { @@ -140,7 +211,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage NVTE_CHECK((ret.lda % 16) == 0, - "Inner dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); + "Leading dimension requirement on NVTE_BLOCK_SCALING GEMM. Caller must pad."); // Divisibility of 8 derived from FP8 (m * CTypeSize) % 16 == 0 requirement. // Smallest supported CType is 2 bytes in this scaling mode. NVTE_CHECK((m % 8) == 0, @@ -168,11 +239,37 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } + } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed + // data with the mirrored transpose-flag if we don't have row-wise data. + NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), + "Input B is missing column-wise usage"); + ret.B = B.columnwise_data.dptr; + ret.transB = is_B_transposed ? CUBLAS_OP_N : CUBLAS_OP_T; + ret.Btype = B.columnwise_data.dtype; + ret.B_scale_inv = B.columnwise_scale_inv.dptr; + ret.ldb = is_B_transposed ? k : n; } - } else if (is_mxfp_scaling(B.scaling_mode)) { - // MXFP8 - // Note: Row-wise and column-wise data are scaled along different - // dimensions (with matrix interpreted in row-major order). 
+ + if (is_fp8_dtype(ret.Atype)) { + // Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + NVTE_CHECK(ret.ldb % 16 == 0, + "Leading dimension requirement on B for FP8 GEMM. Caller must pad."); + } + } else if (nvfp4) { + if (is_B_transposed) { + NVTE_CHECK(is_nvfp4_scaling(B.scaling_mode), + "Input B has unsupported combination of recipe and layout"); + NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); + } else { + NVTE_CHECK(B.has_data(), "Input B is missing row-wise usage"); + } + ret.B = is_B_transposed ? B.columnwise_data.dptr : B.data.dptr; + ret.transB = CUBLAS_OP_N; // NVFP4 gemm is only supported in TN layout. + ret.Btype = is_B_transposed ? B.columnwise_data.dtype : B.data.dtype; + ret.B_scale_inv = is_B_transposed ? B.columnwise_scale_inv.dptr : B.scale_inv.dptr; + ret.ldb = k; + } else if (mxfp8) { if (is_B_transposed) { NVTE_CHECK(B.has_columnwise_data(), "Input B is missing column-wise usage"); } else { @@ -231,7 +328,7 @@ namespace transformer_engine { void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void* workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void* alpha_ptr, const void* beta_ptr, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset = -1); #else // Use cublasLt @@ -241,7 +338,7 @@ using cublasHandleManager = detail::HandleManageramax.dptr != nullptr || inputB->amax.dptr != nullptr)) { + // Reserve some workspace for alpha scale + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for alpha scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + uint8_t 
*workspace_ptr = reinterpret_cast(workspace); + float *new_alpha_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + + // Update alpha scale on device + // Note: Compute NVFP4 tensor scales based on amaxes and then + // divide from alpha scale. This way we only need to apply NVFP4 + // tensor scales in matmul output, instead of in matmul inputs. + float old_alpha = *reinterpret_cast(alpha); // Assumed to be on CPU + TensorWrapper new_alpha_tensor(new_alpha_ptr, std::vector{1}, DType::kFloat32); + nvte_nvfp4_compute_per_tensor_scale(inputA->nvte_tensor, transa, inputB->nvte_tensor, !transb, + old_alpha, new_alpha_tensor.data(), stream); + alpha = new_alpha_ptr; + + // Make sure beta scale is on device + float old_beta = *reinterpret_cast(beta); // Assumed to be on CPU + if (old_beta == 0) { + beta = GetScalarZero(); // Device constant memory + } else if (old_beta == 1) { + beta = GetScalarOne(); // Device constant memory + } else { + // Move beta to workspace + NVTE_CHECK(workspaceSize >= 4, + "NVFP4 GEMM requires at least 4 byte workspace for beta scale, but only has ", + workspaceSize, " bytes remaining."); + workspaceSize = (workspaceSize / 4) * 4 - 4; // Remove last 4 aligned bytes + float *new_beta_ptr = reinterpret_cast(&workspace_ptr[workspaceSize]); + set_float_kernel<<<1, 1, 0, stream>>>(new_beta_ptr, old_beta); + NVTE_CHECK_CUDA(cudaGetLastError()); + beta = new_beta_ptr; + } + } const cudaDataType_t A_type = get_cuda_dtype(param.Atype); const cudaDataType_t B_type = get_cuda_dtype(param.Btype); @@ -290,16 +430,23 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "FP8 input to GEMM requires inverse of scale!"); NVTE_CHECK(!is_fp8_dtype(param.Btype) || param.B_scale_inv != nullptr, "FP8 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Atype) || param.A_scale_inv != nullptr, + "FP4 input to GEMM requires inverse of scale!"); + NVTE_CHECK(!is_fp4_dtype(param.Btype) || param.B_scale_inv != 
nullptr, + "FP4 input to GEMM requires inverse of scale!"); // check consistency of arguments: // if fp8 is desired, context cannot be null // fp8 + gelu fusion + fp8 aux is unavailable right now. - if (use_fp8 && gelu) { + if ((use_fp8 || use_fp4) && gelu) { NVTE_CHECK(!is_fp8_dtype(outputPreGelu->data.dtype), "fp8 Aux output for gemm + gelu fusion not supported!"); } - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!"); + if (is_fp4_dtype(outputD->data.dtype)) { + NVTE_ERROR("FP4 GEMM output is not supported!"); + } + if (use_fp4 && (D_type == CUDA_R_16F)) { + NVTE_ERROR("FP4 GEMM does not support FP16 output!"); } cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); @@ -339,12 +486,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &math_sm_count, sizeof(math_sm_count))); } - // set fp8 attributes -- input and output types should already be set to fp8 as appropriate - // Note: gelu fusion isn't available right now, and we don't need + // set fp8/fp4 attributes -- input and output types should already be set to fp8/fp4 + // as appropriate. Note: gelu fusion isn't available right now, and we don't need // amax(D) either (next op is high precision). - if (use_fp8) { - // Split accumulator. - const int8_t fastAccuMode = (use_split_accumulator) ? 0 : 1; + const bool mxfp8_gemm = !use_fp4 && is_mxfp8_scaling(inputA->scaling_mode); + + if (use_fp8 || use_fp4) { + // Fast accumulation is only supported for FP8. + const int8_t fastAccuMode = (use_split_accumulator) ? 
0 : use_fp8; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fastAccuMode, sizeof(fastAccuMode))); @@ -353,7 +502,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, cublasLtMatmulMatrixScale_t scaling_mode_a; cublasLtMatmulMatrixScale_t scaling_mode_b; #endif // CUBLAS_VERSION >= 120800 - if ((is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode))) { + if (is_tensor_scaling(inputA->scaling_mode) && is_tensor_scaling(inputB->scaling_mode)) { void *A_scale_inverse = param.A_scale_inv; void *B_scale_inverse = param.B_scale_inv; NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, @@ -366,7 +515,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_SCALAR_32F; #endif // CUBLAS_VERSION >= 120800 - } else if ((is_mxfp_scaling(inputA->scaling_mode) && is_mxfp_scaling(inputB->scaling_mode))) { + } else if (mxfp8_gemm) { #if CUBLAS_VERSION >= 120800 NVTE_CHECK(cublas_version() >= 120800, "MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); @@ -391,6 +540,34 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #else NVTE_ERROR("MXFP8 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); +#endif // CUBLAS_VERSION >= 120800 + } else if (use_fp4) { // NVFP4 GEMM +#if CUBLAS_VERSION >= 120800 + NVTE_CHECK(cublas_version() >= 120800, + "FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version()); + // make sure alpha beta computation dtype remains fp32 by CUBLASLT_MATMUL_DESC_SCALE_TYPE + cublasDataType_t scale_type = CUDA_R_32F; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SCALE_TYPE, &scale_type, sizeof(scale_type))); + + // Set pointer mode: alpha and beta are both device pointers + 
// https://docs.nvidia.com/cuda/cublas/#cublasltpointermode-t + cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE; + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pointer_mode, sizeof(pointer_mode))); + + fp8e4m3 *A_scale_inverse = reinterpret_cast(param.A_scale_inv); + fp8e4m3 *B_scale_inverse = reinterpret_cast(param.B_scale_inv); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, + &A_scale_inverse, sizeof(A_scale_inverse))); + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc, + CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + &B_scale_inverse, sizeof(B_scale_inverse))); + scaling_mode_a = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; + scaling_mode_b = CUBLASLT_MATMUL_MATRIX_SCALE_VEC16_UE4M3; +#else + NVTE_ERROR("FP4 requires cuBLAS 12.8+, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif // CUBLAS_VERSION >= 120800 } else if ((inputA->scaling_mode == NVTE_BLOCK_SCALING_1D || inputA->scaling_mode == NVTE_BLOCK_SCALING_2D) && @@ -523,14 +700,11 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif -#if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ - CUBLAS_VERSION < 130000 +#else NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); @@ -557,6 +731,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor 
*outputD, #endif } + // align the workspace to 256 B + const int required_alignment = 256; + const auto original_workspace_alignment = _getAlignment(reinterpret_cast(workspace)); + uint8_t *aligned_workspace_ptr = + reinterpret_cast(workspace) + required_alignment - original_workspace_alignment; + workspaceSize = workspaceSize - required_alignment + original_workspace_alignment; + const auto new_workspace_alignment = + _getAlignment(reinterpret_cast(aligned_workspace_ptr)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceCreate(&preference)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspaceSize, sizeof(workspaceSize))); @@ -564,7 +746,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const auto B_alignment = _getAlignment(reinterpret_cast(param.B)); const auto C_alignment = _getAlignment(reinterpret_cast(C)); const auto D_alignment = _getAlignment(reinterpret_cast(D)); - const auto workspace_alignment = _getAlignment(reinterpret_cast(workspace)); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_A_BYTES, &A_alignment, sizeof(A_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( @@ -573,8 +754,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_C_BYTES, &C_alignment, sizeof(C_alignment))); NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MIN_ALIGNMENT_D_BYTES, &D_alignment, sizeof(D_alignment))); - NVTE_CHECK(workspace_alignment % 256 == 0, - "cuBLAS workspace pointer must be aligned to 256 bytes, got ", workspace_alignment); + NVTE_CHECK(new_workspace_alignment % 256 == 0, + "cuBLAS workspace pointer must be aligned to 256 bytes, got ", + new_workspace_alignment); const auto status = cublasLtMatmulAlgoGetHeuristic(handle, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, @@ 
-585,16 +767,15 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, if (returnedResults == 0) NVTE_ERROR("Unable to find any suitable algorithms"); // D = alpha * (A * B) + beta * C - NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, - static_cast(&alpha), /* alpha */ - param.A, /* A */ - Adesc, param.B, /* B */ - Bdesc, static_cast(&beta), /* beta */ - C, /* C */ - Cdesc, D, /* D */ - Ddesc, &heuristicResult.algo, /* algo */ - workspace, /* workspace */ - workspaceSize, stream)); /* stream */ + NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, alpha, /* alpha */ + param.A, /* A */ + Adesc, param.B, /* B */ + Bdesc, beta, /* beta */ + C, /* C */ + Cdesc, D, /* D */ + Ddesc, &heuristicResult.algo, /* algo */ + aligned_workspace_ptr, /* workspace */ + workspaceSize, stream)); /* stream */ // Update FP8 scale-inv in output tensor // Note: This is a WAR for the case when we have fp8 output but D->scale_inv is not allocated. @@ -621,35 +802,117 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons int math_sm_count, cudaStream_t stream) { NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if (is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. 
Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, - nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); +} + +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm_v2); + using namespace transformer_engine; + + // Data tensors + const Tensor *A_tensor = convertNVTETensorCheck(A); + const Tensor *B_tensor = convertNVTETensorCheck(B); + const Tensor *C_tensor = convertNVTETensorCheck(C); + Tensor *D_tensor = convertNVTETensorCheck(D); + NVTE_CHECK(C_tensor == D_tensor, + "Currently nvte_cublas_gemm_v2 does not support different C and D tensors."); + + // Workspace + void *workspace_ptr = nullptr; + size_t workspace_size = 0; + Tensor *workspace_tensor = convertNVTETensor(workspace); + if (workspace_tensor != nullptr) { + workspace_ptr = workspace_tensor->data.dptr; + workspace_size = + get_buffer_size_bytes(workspace_tensor->data.numel(), workspace_tensor->data.dtype); + } + + // Additional config + MatmulConfig config_; + if (config != nullptr) { + config_ = *reinterpret_cast(config); + } + + // Configure GEMM epilogue + const bool with_grad_epilogue = (config_.dbias_tensor != nullptr || config_.with_dgelu_epilogue); + if (with_grad_epilogue) { + NVTE_CHECK(config_.bias_tensor == nullptr && !config_.with_gelu_epilogue, + "Invalid epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ")."); + } + Tensor 
dummy_tensor; + Tensor *epilogue_bias_tensor = &dummy_tensor; + if (!with_grad_epilogue && config_.bias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.bias_tensor); + } else if (with_grad_epilogue && config_.dbias_tensor != nullptr) { + epilogue_bias_tensor = convertNVTETensorCheck(config_.dbias_tensor); + } + Tensor *epilogue_aux_tensor = &dummy_tensor; + if (config_.with_gelu_epilogue || config_.with_dgelu_epilogue) { + NVTE_CHECK(config_.epilogue_aux_tensor != nullptr, + "Requested epilogue (bias=", config_.bias_tensor != nullptr, + ", dbias=", config_.dbias_tensor != nullptr, ", gelu=", config_.with_gelu_epilogue, + ", dgelu=", config_.with_dgelu_epilogue, ") without providing aux tensor."); + epilogue_aux_tensor = convertNVTETensor(config_.epilogue_aux_tensor); + } + + // Launch GEMM + cublas_gemm(A_tensor, B_tensor, D_tensor, epilogue_bias_tensor, epilogue_aux_tensor, + transa ? CUBLAS_OP_T : CUBLAS_OP_N, transb ? CUBLAS_OP_T : CUBLAS_OP_N, + with_grad_epilogue, workspace_ptr, workspace_size, alpha, beta, + config_.use_split_accumulator, config_.sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, float alpha, float beta, bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { - NVTE_API_CALL(nvte_cublas_gemm_scaled); + NVTE_API_CALL(nvte_cublas_gemm); using namespace transformer_engine; + + // Tensors const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensor(D); + Tensor *outputD = convertNVTETensorCheck(D); const Tensor *biasTensor = convertNVTETensor(bias); Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); + // Check for NVFP4 + // TODO Remove once alpha scale logic is moved into cublas_gemm function + if 
(is_nvfp_scaling(inputA->scaling_mode) || is_nvfp_scaling(inputB->scaling_mode)) { + NVTE_ERROR("nvte_cublas_gemm does not support NVFP4 data. Use nvte_cublas_gemm_v2 instead."); + } + + // Launch GEMM cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); + &alpha, &beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -666,12 +929,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor #if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); -#endif -#if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) +#elif !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) NVTE_ERROR( "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); -#endif +#else NVTE_CHECK( transformer_engine::cuda::cudart_version() >= 12020 && transformer_engine::cuda::cudart_version() < 13000, @@ -681,7 +943,7 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); -#endif //__HIP_PLATFORM_AMD__ +#endif const Tensor *inputA = convertNVTETensorCheck(A); const Tensor *inputB = convertNVTETensorCheck(B); @@ -691,13 +953,17 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor const Tensor *inputCounter = convertNVTETensor(counter); Tensor *wspace = convertNVTETensor(workspace); + const void *alpha_ptr = GetScalarOne(); + const void *beta_ptr = accumulate ? 
GetScalarOne() : GetScalarZero(); + NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, - n_split, gemm_producer, inputCounter, stream); + alpha_ptr, beta_ptr, use_split_accumulator, math_sm_count, m_split, n_split, + gemm_producer, inputCounter, stream); +#endif //#ifndef __HIP_PLATFORM_AMD__ } void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, @@ -719,24 +985,47 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens for (int i = 0; i < num_gemms; i++) { #ifdef __HIP_PLATFORM_AMD__ - { - const Tensor *inputA = convertNVTETensorCheck(A[i]); - const Tensor *inputB = convertNVTETensorCheck(B[i]); - Tensor *outputD = convertNVTETensorCheck(D[i]); - const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); - Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); - Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, - wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 
1.0f : 0.0f, - use_split_accumulator, math_sm_count, 0, 0, false, nullptr, - detail::get_compute_stream(i % num_streams), i % num_streams); - } + const Tensor *inputA = convertNVTETensorCheck(A[i]); + const Tensor *inputB = convertNVTETensorCheck(B[i]); + Tensor *outputD = convertNVTETensorCheck(D[i]); + const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); + Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); + Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); + + // Scales + const float alpha = 1; + const float beta = accumulate ? 1 : 0; + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], &alpha, &beta, + use_split_accumulator, math_sm_count, 0, 0, false, nullptr, + detail::get_compute_stream(i % num_streams), i % num_streams); #else - nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams)); + // Check whether GELU or dGELU epilogue is requested + Tensor *pre_gelu_tensor = convertNVTETensor(pre_gelu_out[i]); + bool with_gelu_dgelu_epilogue = + (pre_gelu_tensor != nullptr && pre_gelu_tensor->data.dptr != nullptr); + + // Construct config + MatmulConfig config; + if (grad) { + config.dbias_tensor = bias[i]; + config.with_dgelu_epilogue = with_gelu_dgelu_epilogue; + } else { + config.bias_tensor = bias[i]; + config.with_gelu_epilogue = with_gelu_dgelu_epilogue; + } + config.epilogue_aux_tensor = pre_gelu_out[i]; + config.use_split_accumulator = use_split_accumulator; + config.sm_count = math_sm_count; + + // Launch GEMM + const float alpha = 1.f; + const float beta = accumulate ? 
1.f : 0.f; + nvte_cublas_gemm_v2(transa, transb, &alpha, A[i], B[i], &beta, D[i], D[i], + workspace[i % num_streams], &config, + detail::get_compute_stream(i % num_streams)); #endif } diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index fef3966a5..205a0a058 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -1,5 +1,5 @@ /************************************************************************* - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ @@ -1501,7 +1501,7 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, - float alpha, float beta, bool use_split_accumulator, int math_sm_count, + const void* alpha_ptr, const void* beta_ptr, bool use_split_accumulator, int math_sm_count, [[maybe_unused]] int m_split, [[maybe_unused]] int n_split, [[maybe_unused]] bool gemm_producer, [[maybe_unused]] const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset) @@ -1527,6 +1527,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int ldb = is_transb ? n : k; const int ldd = m; + float alpha = *reinterpret_cast(alpha_ptr); // Assumed to be on CPU + float beta = *reinterpret_cast(beta_ptr); // Assumed to be on CPU + ServiceStreamCtl ss_ctl; bool use_service_stream = (math_sm_count != 0) ? 
get_service_stream(math_sm_count, stream, ss_ctl) : false; diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu new file mode 100644 index 000000000..9d4bec41d --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -0,0 +1,876 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#include + +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace { + +constexpr int kThreadsPerWarp = 32; +constexpr float k16x16HadamardScale = 0.25f; + +template +__device__ __forceinline__ void ldmatrix_x4_m8n8_shared_b16(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr) { + auto smem_addr = static_cast(__cvta_generic_to_shared(addr)); + if constexpr (kTranspose) { + asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } else { + asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "r"(smem_addr)); + } +} + +template +__device__ __forceinline__ void load_matrix_16x16_from_shared(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, + void* addr, uint32_t stride) { + if constexpr (kTranspose) { + asm volatile( + "wmma.load.a.sync.aligned.col.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : "=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } else { + asm volatile( + "wmma.load.a.sync.aligned.row.m16n16k16.shared::cta.bf16 " + "{%0,%1,%2,%3}, [%4], %5;\n" + : 
"=r"(a0), "=r"(a1), "=r"(a2), "=r"(a3) + : "l"(addr), "r"(stride)); + } +} + +template +__device__ __forceinline__ void store_matrix_16x16_to_global(uint32_t& a0, uint32_t& a1, + uint32_t& a2, uint32_t& a3, void* addr, + uint32_t stride) { + if constexpr (kTranspose) { + asm volatile("wmma.store.d.sync.aligned.col.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } else { + asm volatile("wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%0], {%1, %2, %3, %4}, %5;\n" + : + : "l"(addr), "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(stride)); + } +} + +__device__ __forceinline__ void matrix_transpose_m8_n8_b16_inplace(uint32_t& a0) { + asm volatile( + "movmatrix.sync.aligned.m8n8.trans.b16 " + "%0, %1;\n\t" + : "=r"(a0) + : "r"(a0)); +} + +__device__ __forceinline__ void unpack_max_of_packed_bf16(uint32_t& packed_bf16, float& float_dst) { + __nv_bfloat162 bf16x2 = *reinterpret_cast<__nv_bfloat162*>(&packed_bf16); + float f_a = __bfloat162float(bf16x2.x); + float f_b = __bfloat162float(bf16x2.y); + asm volatile("max.xorsign.abs.f32 %0, %1, %2;\n\t" : "=f"(float_dst) : "f"(f_a), "f"(f_b)); + float_dst = fabsf(float_dst); +} + +template +__device__ __forceinline__ void mma_m16_n16_k16_b16_b16_b16_noacc( + uint32_t& a0, uint32_t& a1, uint32_t& a2, uint32_t& a3, uint32_t& b0, uint32_t& b1, + uint32_t& b2, uint32_t& b3, uint32_t& c0, uint32_t& c1, uint32_t& c2, uint32_t& c3, + uint32_t& amax_result) { + uint32_t zero = 0; + uint32_t temp0, temp1, temp2, temp3, temp4, temp5, temp6, temp7; + asm volatile( + "wmma.mma.sync.aligned.row.row.m16n16k16.f32.bf16.bf16.f32 \n" + "{%0, %1, %2, %3, %4, %5, %6, %7}, \n" + "{%8, %9, %10, %11}, \n" + "{%12, %13, %14, %15}, \n" + "{%16, %17, %18, %19, %20, %21, %22, %23};\n\t" + : "=r"(temp0), "=r"(temp1), "=r"(temp2), "=r"(temp3), "=r"(temp4), "=r"(temp5), "=r"(temp6), + "=r"(temp7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), "r"(b0), "r"(b1), "r"(b2), "r"(b3), "r"(zero), + 
"r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero), "r"(zero)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c0) : "r"(temp1), "r"(temp0)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c1) : "r"(temp3), "r"(temp2)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c2) : "r"(temp5), "r"(temp4)); + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" : "=r"(c3) : "r"(temp7), "r"(temp6)); + if constexpr (kCalculateAmax) { + uint32_t max_even; + uint32_t max_odd; + // Reduction tree to amax(abs(result)) into bf16x2 reg outparam. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_even) : "r"(c0), "r"(c2)); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" : "=r"(max_odd) : "r"(c1), "r"(c3)); + // N.B. mma is only called up to once per thread for identity and transpose respectively, so + // we don't have to accumulate into amax_result and can directly store into it. + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(amax_result) + : "r"(max_even), "r"(max_odd)); + } +} + +template +__device__ __forceinline__ void get_hadamard_matrix_fragment(uint32_t* had_frag_i, + uint16_t random_sign_mask, + uint32_t* had_frag_t, + uint16_t random_sign_mask_t) { + int32_t tid = threadIdx.x % 32; // Local tid + float temp_i[2]; + float temp_t[2]; +#pragma unroll + for (int i = 0; i < 2; i++) { + // i is the vertical fragment index. + // For a 16x16 matrix matrix fragment, 4 threads fill a fragment of 8 BF16 vals. + uint32_t r = i * 8 + tid / 4; + +#pragma unroll + for (int j = 0; j < 2; j++) { +#pragma unroll + for (int k = 0; k < 2; k++) { + // k is column position [0, 1] within a quad of 2 BF16s stored together in 32 bits. + // j is the column fragment idx selecting between even and odd fragments. + // j increments 8 columns by switching fragments. 
+ uint32_t c = j * 8 + k + tid % 4 * 2; + // 1 -> -1.0f, 0 -> 1.0f + int32_t base_sign = __popc(r & c); + if constexpr (kReturnIdentity) { + int32_t sign_i; + // Because tensor cores want the dot product dimension, + // contiguous, the regular, non-inverse hadamard swaps + // signs of columns and rows for inverse. In a simple reference, + // x.reshape(-1, 16) @ sign @ H16, this would be opposite but + // (sign @ H16) is transposed in this fragment. + if constexpr (kInverseHadamardIdentity) { + sign_i = ((random_sign_mask >> r) ^ base_sign); + } else { + sign_i = ((random_sign_mask >> c) ^ base_sign); + } + temp_i[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_i << 31)); + } + if constexpr (kReturnTransposed) { + int32_t sign_t; + if constexpr (kInverseHadamardTransposed) { + sign_t = ((random_sign_mask_t >> r) ^ base_sign); + } else { + sign_t = ((random_sign_mask_t >> c) ^ base_sign); + } + temp_t[k] = copysignf(k16x16HadamardScale, __int_as_float(sign_t << 31)); + } + } + + if constexpr (kReturnIdentity) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_i[i * 2 + j]) + : "f"(temp_i[1]), "f"(temp_i[0])); + } + if constexpr (kReturnTransposed) { + asm volatile("cvt.rn.bf16x2.f32 %0, %1, %2;\n\t" + : "=r"(had_frag_t[i * 2 + j]) + : "f"(temp_t[1]), "f"(temp_t[0])); + } + } + } +} + +__device__ __forceinline__ uint32_t swizzle_128B_atom_32B(uint32_t gmem_row_idx, + uint32_t gmem_col_idx) { + uint32_t smem_row_idx = gmem_row_idx; + uint32_t xor_factor = (smem_row_idx * 2) % 8; + uint32_t smem_col_idx = gmem_col_idx ^ xor_factor; + return smem_row_idx * 8 + smem_col_idx; +} + +template +__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4], + IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg, + uint32_t& local_amax_reg, + uint32_t& local_amax_t_reg) { + uint32_t a_frag[4]; // A matrix fragment + uint32_t c_frag[4]; // Result fragment + + int warp_id = threadIdx.x / kThreadsPerWarp; + int local_rank = 
(threadIdx.x % kThreadsPerWarp); + + int ld_row_idx = local_rank % kHadamardDimension; + int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2; + int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx); + + uint32_t temp_amax_reg; + uint32_t temp_amax_t_reg; + + if (kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_reg) + : "r"(local_amax_reg), "r"(temp_amax_reg)); + } + + if (kReturnTransposedAmax) { + // TODO(Frank): This is not efficient, since we could directly load the + // matrix in transposed layout. + if (!kReturnIdentityAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], + b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_amax_t_reg) + : "r"(local_amax_t_reg), "r"(temp_amax_t_reg)); + } + + if (kReturnPreRhtAmax) { + if (!kReturnIdentityAmax && !kReturnTransposedAmax) { + ldmatrix_x4_m8n8_shared_b16(a_frag[0], a_frag[1], a_frag[2], a_frag[3], + reinterpret_cast(in_sh_ptr) + swizzle_idx); + } + + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[1])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[2]) + : 
"r"(a_frag[2]), "r"(a_frag[3])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(a_frag[0]) + : "r"(a_frag[0]), "r"(a_frag[2])); + asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t" + : "=r"(local_pre_rht_amax_reg) + : "r"(a_frag[0]), "r"(local_pre_rht_amax_reg)); + } +} + +template +__device__ __host__ constexpr int NextPowerOf2() { + static_assert(kN > 0, "kN must be > 0"); + // Round up to the next power of 2 by counting leading zeros. + return 1 << (32 - __builtin_clz(kN - 1)); +} + +template +__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax, + const float transpose_amax, float* staging_for_pre_rht, + float* staging_for_identity, float* staging_for_transpose, + float* output_pre_rht_amax_ptr, + float* output_identity_amax_ptr, + float* output_transpose_amax_ptr, const int warpid) { + // intra-warp reduction + constexpr int kWarpSize = 32; + int local_rank = threadIdx.x % 32; + float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max(pre_rht_amax) : 0.0f; + float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max(identity_amax) : 0.0f; + float warp_transpose_amax = + kReturnTransposedAmax ? warp_reduce_max(transpose_amax) : 0.0f; + + // inter-warp reduction + if (threadIdx.x % 32 == 0) { + if (kReturnPreRhtAmax) { + staging_for_pre_rht[warpid] = warp_pre_rht_amax; + } + if (kReturnIdentityAmax) { + staging_for_identity[warpid] = warp_identity_amax; + } + if (kReturnTransposedAmax) { + staging_for_transpose[warpid] = warp_transpose_amax; + } + } + __syncthreads(); + constexpr int kNumWarpsPow2 = NextPowerOf2(); + if (warpid == 0) { + if (kReturnIdentityAmax) { + float identity_accum = local_rank < kNumWarps ? 
staging_for_identity[local_rank] : 0.0f; + identity_accum = warp_reduce_max(identity_accum); + if (local_rank == 0) { + atomicMaxFloat(output_identity_amax_ptr, identity_accum); + } + } + } + if (warpid == 1) { + if (kReturnTransposedAmax) { + float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f; + transpose_accum = warp_reduce_max(transpose_accum); + if (local_rank == 0) { + atomicMaxFloat(output_transpose_amax_ptr, transpose_accum); + } + } + } + if (warpid == 2) { + if (kReturnPreRhtAmax) { + float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f; + pre_rht_accum = warp_reduce_max(pre_rht_accum); + if (local_rank == 0) { + atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum); + } + } + } +} + +__launch_bounds__(1) __global__ void ZeroAmaxKernel(float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr) { + if (output_pre_rht_amax_ptr != nullptr) { + *output_pre_rht_amax_ptr = 0; + } + if (output_identity_amax_ptr != nullptr) { + *output_identity_amax_ptr = 0; + } + if (output_transpose_amax_ptr != nullptr) { + *output_transpose_amax_ptr = 0; + } +} + +template +__global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor_map_input, + float* __restrict__ output_pre_rht_amax_ptr, + float* __restrict__ output_identity_amax_ptr, + float* __restrict__ output_transpose_amax_ptr, + uint16_t random_sign_mask, uint16_t random_sign_mask_t, + uint64_t num_rows, uint64_t row_length) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + + static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0); + static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0); + + constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y; + constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X; + + constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp; + + const int 
input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X; + + extern __shared__ __align__(128) char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uint8_t* dshmem = reinterpret_cast((base_shmem_ptr + 127) & ~127ULL); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + IType* in_sh_0 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + IType* in_sh_1 = reinterpret_cast(dshmem); + dshmem += in_buff_size; + + IType* in_shs[2] = {in_sh_0, in_sh_1}; + + constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType); + + const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0); + + // Initialize shared memory barrier with the number of threads participating in the barrier. 
+#pragma nv_diag_suppress static_var_with_dynamic_init + uint64_t* mbar = reinterpret_cast(dshmem); + dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y); + + float* max_staging_identity = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_transpose = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + float* max_staging_pre_rht = reinterpret_cast(dshmem); + dshmem += sizeof(float) * kNumWarps; + + initialize_barriers(mbar, + is_master_thread); + + copy_2d_to_shared(in_shs[0], reinterpret_cast(&tensor_map_input), + input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0], + is_master_thread); + + uint32_t had_frag_i[4]; + uint32_t had_frag_t[4]; + get_hadamard_matrix_fragment( + had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t); + + float local_pre_rht_amax = 0.0; + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_pre_rht_amax_reg = *reinterpret_cast(&local_pre_rht_amax); + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + + for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) { + for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) { + int stage = STAGES_X * stage_y + stage_x; + + const int next_stage = stage + 1; + const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1; + const int next_stage_y = stage_x + 1 == STAGES_X ? 
stage_y + 1 : stage_y; + + if (next_stage < STAGES_X * STAGES_Y) { + const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y; + const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X; + + copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong + reinterpret_cast(&tensor_map_input), input_global_offset_X, + input_global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); + } + + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t compute_stage_x_num = + BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)); + const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y); + + const size_t in_row_stride = BUFF_DIM_X; + + IType* in_sh_ptr = in_shs[stage % 2]; + +#pragma unroll + for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) { + const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y + + threadIdx.y * kHadamardDimension); + const int in_row_offset = row_idx_offset * in_row_stride; + +#pragma unroll + for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) { + ComputeKernel( + had_frag_i, had_frag_t, + in_sh_ptr + in_row_offset + + (compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)), + local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg); + } + + // Ensure all threads have finished their computation before new data over-writes the shared + // memory. 
+ __syncthreads(); + } + } + } + + const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp; + + if constexpr (kReturnPreRhtAmax) { + unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax); + } + if constexpr (kReturnIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + } + if constexpr (kReturnTransposedAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + } + + ReduceMax( + local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity, + max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr, + output_transpose_amax_ptr, warpid); + + destroy_barriers(mbar, is_master_thread); +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +template +__global__ void HadamardTransformKernel(const T* __restrict__ input, T* __restrict__ output, + T* __restrict__ output_t, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, uint64_t num_input_rows, + uint64_t num_input_cols, float* __restrict__ amax, + float* __restrict__ amax_t, bool inverse_hadamard) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 + static_assert(kHadamardDimension == 16, "Currently only hadamard dimension 16 is supported."); + + // The whole threadblock will share the same smem. + extern __shared__ __align__(16) T smem[]; + + // Each 32 threads process a 16x16 matrix. There is a (y, z) grid of 16x16. + // If y = 4, z = 4, then each threadblock is processing a 4x4 grid of 16x16 matrices. 
+ int32_t tid = threadIdx.x; + int32_t warp_id = threadIdx.y * blockDim.z + threadIdx.z; + int32_t local_bx = threadIdx.y; + int32_t local_by = threadIdx.z; + + // Define the register fragments + uint32_t a_frag[4]; // A matrix fragment + uint32_t b_frag_i[4]; // Transposed Hadamard matrix fragment, used for A @ B(col major) + uint32_t b_frag_t[4]; // Hadamard matrix fragment, used for A.T @ B.T(col major) + uint32_t c_frag[4]; // Result fragment + + // row and col for each thread. 32 threads will work together in 128 chunk to + // load the data from global memory to shared memory. + uint32_t row = tid / (kHadamardDimension * sizeof(T) / sizeof(uint4)); + uint32_t col = tid % (kHadamardDimension * sizeof(T) / sizeof(uint4)); + + uint32_t smem_index = tid; + + uint32_t input_start_col = (blockIdx.x * blockDim.y + local_bx) * kHadamardDimension; + uint32_t input_start_row = (blockIdx.y * blockDim.z + local_by) * kHadamardDimension; + + bool load = (input_start_col < num_input_cols) && (input_start_row < num_input_rows); + if (!load) { + // Out of bound, we are returning early. No thread divergence since the whole warp + // will return early. + return; + } + + uint64_t global_offset = input_start_col + input_start_row * num_input_cols; + uint64_t global_offset_t = + kOutputTrueTransposed ? (input_start_row + input_start_col * num_input_rows) : global_offset; + + T* base_smem = smem + kHadamardDimension * kHadamardDimension * warp_id; + + uint32_t* smem_b32 = reinterpret_cast(base_smem); + uint4* smem_b128 = reinterpret_cast(base_smem); + + // Asynchronously load the data from global memory to shared memory. + const uint4* input_b128 = reinterpret_cast(input + global_offset); + // Each 16x16 chunk is divided into 4 8x8 matrices, we are trying to load each + // 8x8 chunks consecutively into the smem, so we could leverage ldmatrix m8n8x4 + // to load the data in the tensor core swizzled format. 
+ __pipeline_memcpy_async(&smem_b128[smem_index], + &input_b128[row * num_input_cols / (sizeof(uint4) / sizeof(T)) + col], + sizeof(uint4)); + __pipeline_commit(); // Commit the memcpy. Wait when we are in the computation. + + if (inverse_hadamard) { + get_hadamard_matrix_fragment(b_frag_i, random_sign_mask, + b_frag_t, random_sign_mask_t); + } else { + get_hadamard_matrix_fragment( + b_frag_i, random_sign_mask, b_frag_t, random_sign_mask_t); + } + + float local_amax = 0.0; + float local_amax_t = 0.0; + uint32_t local_amax_reg = *reinterpret_cast(&local_amax); + uint32_t local_amax_t_reg = *reinterpret_cast(&local_amax_t); + __pipeline_wait_prior(0); + + __syncwarp(); // ensure all lanes finished their cp.async before reading smem + + // Load the A to a_frag. + if constexpr (kComputeIdentity) { + load_matrix_16x16_from_shared(a_frag[0], a_frag[1], a_frag[2], a_frag[3], smem_b32, + kHadamardDimension); + + // 16x16 @ 16x16 leveraging all threads in the warp. + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2], + b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], local_amax_reg); + + // Store the result to the shared memory in non-transposed order. 
+ if constexpr (kReturnIdentity) { + uint4* output_b128 = reinterpret_cast(output + global_offset); + store_matrix_16x16_to_global(c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_b128, + num_input_cols); + } + } + + if constexpr (kComputeTransposed) { + if (kComputeIdentity) { + matrix_transpose_m8_n8_b16_inplace(a_frag[0]); + matrix_transpose_m8_n8_b16_inplace(a_frag[1]); + matrix_transpose_m8_n8_b16_inplace(a_frag[2]); + matrix_transpose_m8_n8_b16_inplace(a_frag[3]); + } else { + load_matrix_16x16_from_shared(a_frag[0], + a_frag[2], // NOTE: intentional index swapping + a_frag[1], // NOTE: intentional index swapping + a_frag[3], smem_b32, kHadamardDimension); + } + + mma_m16_n16_k16_b16_b16_b16_noacc( + a_frag[0], + // 2,1 is used if we are using movmatrix instruction. + // Thus loading the matrix in 2,1 order will just be normal. + // This is to be compatible with the movmatrix instruction. + a_frag[2], // NOTE: intentional index swapping for transpose purpose. + a_frag[1], // NOTE: intentional index swapping for transpose purpose. + a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2], b_frag_t[3], c_frag[0], c_frag[1], + c_frag[2], c_frag[3], local_amax_t_reg); + + // Store the result to the shared memory in non-transposed order. + if constexpr (kReturnTransposed) { + uint4* output_t_b128 = reinterpret_cast(output_t + global_offset_t); + store_matrix_16x16_to_global( + c_frag[0], c_frag[1], c_frag[2], c_frag[3], output_t_b128, + kOutputTrueTransposed ? num_input_rows : num_input_cols); + } + } + + if constexpr (kUpdateIdentityAmax) { + unpack_max_of_packed_bf16(local_amax_reg, local_amax); + local_amax = warp_reduce_max(local_amax); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax = __shfl_sync(0xFFFFFFFF, local_amax, lane_zero); + // atomic CAS to output memory. 
+ if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax, local_amax); + } + } + if constexpr (kUpdateTransposeAmax) { + unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t); + local_amax_t = warp_reduce_max(local_amax_t); + // broadcast the amax to all threads in a warp from the lane 0 + constexpr int lane_zero = 0; + local_amax_t = __shfl_sync(0xFFFFFFFF, local_amax_t, lane_zero); + // atomic CAS to output memory. + if (tid % kThreadsPerWarp == 0) { + atomicMaxFloat(amax_t, local_amax_t); + } + } +#else + NVTE_DEVICE_ERROR("Kernel is only supported on SM 9.0+."); +#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +} + +} // namespace + +void hadamard_transform(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform); + + // Check tensors + // NOTE (frsun): This is non-intuitive, we are writing the result of + // transposed RHT to the output of rowwise. + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + NVTE_CHECK(output_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be simple tensor, but scaling mode is ", + to_string(output_.scaling_mode), "."); + const SimpleTensor& input = input_.data; + SimpleTensor output; + SimpleTensor& output_t = output_.data; + + // Check requested outputs + const bool return_identity = output.dptr != nullptr; + const bool return_transposed = output_t.dptr != nullptr; + if (!return_identity && !return_transposed) { // Nothing to do/ill-defined behavior. 
+ return; + } + + checkCuDriverContext(stream); + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + using IType = bf16; + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kThreadBlockX = 4; + // Configure 4 is used for Hopper, 8 is used for Blackwell for extra memory bandwidth. + constexpr uint64_t kThreadBlockY = 4; + + uint64_t kNumWarpsPerSM = kThreadBlockX * kThreadBlockY; + + // The shared memory number of bytes required for **the whole threadblock**. + size_t shmem_bytes = kHadamardDimension * kHadamardDimension * sizeof(IType) * kNumWarpsPerSM; + + dim3 block(kThreadsPerWarp, kThreadBlockX, kThreadBlockY); + + dim3 grid(DIVUP(row_length / kHadamardDimension, kThreadBlockX), + DIVUP(num_rows / kHadamardDimension, kThreadBlockY)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed, kReturnTransposed, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + auto kernel = + HadamardTransformKernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_bytes); + + kernel<<>>( + reinterpret_cast(input.dptr), reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), random_sign_mask, random_sign_mask_t, + num_rows, row_length, nullptr, nullptr, false););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +// Kernel that will apply the 16x16 hadamard transform the input and input.T, and then +// get the absolute max value of the result. 
+void hadamard_transform_amax(const Tensor& input_, Tensor& output_, uint16_t random_sign_mask, + uint16_t random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_amax); +#if CUDA_VERSION >= 12080 + + // Check input tensor + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor& input = input_.data; + + // Check amax tensors + SimpleTensor& output_pre_rht_amax = output_.amax; + SimpleTensor output_identity_amax; + SimpleTensor& output_transpose_amax = output_.columnwise_amax; + + // Check requested outputs + const bool return_pre_rht_amax = output_pre_rht_amax.dptr != nullptr; + const bool return_identity_amax = output_identity_amax.dptr != nullptr; + const bool return_transposed_amax = output_transpose_amax.dptr != nullptr; + if (!return_identity_amax && !return_transposed_amax && + !return_pre_rht_amax) { // Nothing to do/ill-defined behavior. 
+ return; + } + + // Zero out amaxes if needed + ZeroAmaxKernel<<<1, 1, 0, stream>>>(reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + + checkCuDriverContext(stream); + + using IType = bf16; + + const size_t ndim = input.shape.size(); + const size_t row_length = input.shape[ndim - 1]; + size_t num_rows = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + num_rows *= input.shape[i]; + } + + constexpr int kHadamardDimension = 16; + NVTE_CHECK(row_length % kHadamardDimension == 0, + "row_length must be divisible by hadamard_dimension."); + NVTE_CHECK(num_rows % kHadamardDimension == 0, + "num_rows must be divisible by hadamard_dimension"); + + constexpr uint64_t kChunkBlockXSmall = 128; + constexpr uint64_t kChunkBlockYSmall = 128; + constexpr uint64_t kBuffDimX = 64; + constexpr uint64_t kBuffDimY = 64; + + alignas(64) CUtensorMap tensor_map_input{}; + + create_2D_tensor_map( + /*tensorMap=*/tensor_map_input, + /*tensor=*/input, + /*globalY=*/num_rows, + /*globalX=*/row_length, + /*shmemY=*/kBuffDimY, + /*shmemX=*/kBuffDimX, + /*stride_elems=*/row_length, + /*offset_elems=*/0, + /*type_num_bits=*/sizeof(IType) * 8, + /*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B); + + constexpr uint64_t kThreadBlockX = 4; + constexpr uint64_t kThreadBlockY = 1; + constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY; + + dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY); + + dim3 grid(DIVUP(row_length, kChunkBlockXSmall), DIVUP(num_rows, kChunkBlockYSmall)); + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transposed_amax, kReturnTransposedAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity_amax, kReturnIdentityAmax, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_pre_rht_amax, kReturnPreRhtAmax, + + // *2 for ping-pong + size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType); + size_t mbar_size = 
sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) * + (kChunkBlockYSmall / kBuffDimY); + size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3; + // Add padding in case shmem ptr is not aligned to 128 bytes. + shmem_bytes = (shmem_bytes + 128); + + auto kernel = HadamardAmaxTmaKernel< + IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY, + kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax, + kReturnIdentityAmax, kReturnTransposedAmax>; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + shmem_bytes); + + kernel<<>>( + tensor_map_input, reinterpret_cast(output_pre_rht_amax.dptr), + reinterpret_cast(output_identity_amax.dptr), + reinterpret_cast(output_transpose_amax.dptr), random_sign_mask, + random_sign_mask_t, num_rows, row_length);))); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ", + CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace transformer_engine + +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform); + using namespace transformer_engine; + hadamard_transform(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} + +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_amax); + using namespace transformer_engine; + hadamard_transform_amax(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + static_cast(random_sign_mask), + static_cast(random_sign_mask_t), stream); +} diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu 
b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu new file mode 100644 index 000000000..12f02dba6 --- /dev/null +++ b/transformer_engine/common/hadamard_transform/hadamard_transform_cast_fusion.cu @@ -0,0 +1,834 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/util/cuda_runtime.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" +#include "cutlass/arch/barrier.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/collective/builders/sm100_common.inl" +#include "cutlass/numeric_conversion.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/helper_cuda.hpp" +#include "cutlass/util/print_error.hpp" + +// clang-format off + +namespace transformer_engine { +namespace detail { +namespace { + +using namespace cute; +using cute::Tensor; // Ensure unqualified Tensor refers to cute::Tensor, not transformer_engine::Tensor + +// calculate the global encode scale factor for a given global amax. +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float kFP8E4M3Max = 448.0f; + constexpr float kFP4E2M1Max = 6.0f; + // If scale is infinity, return max value of float32 + float global_encode_scale = cutlass::minimum_with_nan_propagation{}( + kFP8E4M3Max * kFP4E2M1Max / global_amax, cutlass::platform::numeric_limits::max()); + // If global amax is 0 or infinity, return 1 + return (global_amax == 0.f || global_encode_scale == 0.f) ? 
1.f : global_encode_scale; +} + +template +struct SharedStorage { + static constexpr int AccumulatorPipelineStageCount = 16; + using AtomThrShapeMNK = cute::Shape<_1, _1, _1>; + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage; + + alignas(16) AccumulatorPipelineStorage accumulator; + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) cute::uint64_t tma_barrier[1]; + uint32_t tmem_base_ptr; + + struct TensorStorage : cute::aligned_struct<128, _1> { + // cute::array_aligned> smem_A; + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + +}; + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverterBase(cutlass::Array const &input, cutlass::Array const &rbits) { + using result_type = cutlass::Array; + result_type output; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + auto output_ptr = reinterpret_cast(&output); + asm volatile( \ + "{\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n" \ + "cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n" \ + "}" \ + : "=h"(output_ptr[0]), + "=h"(output_ptr[1]) + : "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), + "f"(input[4]), "f"(input[5]), "f"(input[6]), "f"(input[7]), + "r"(rbits[0]), "r"(rbits[1])); + } else { + NVTE_DEVICE_ERROR("FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return output; +} + +CUTLASS_DEVICE +cutlass::Array +StochasticNumericConverter(cutlass::Array const &input, cutlass::Array const *rbits) { + using result_type = cutlass::Array; + result_type output; + cutlass::Array *result_ptr = reinterpret_cast *>(&output); + cutlass::Array const *source_ptr = reinterpret_cast const *>(&input); + cutlass::Array const *rbits_ptr = reinterpret_cast const *>(rbits); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; i++) { + result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]); + } + return output; +} + +template +__global__ static +void +rht_gemm_device(MShape M, NShape N, KShape K, ClusterTileShape cluster_tile, + TA const* A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a, + TB const* B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b, + TC * C, CStride dC, CSmemLayout , + TSFC * SFC, + TiledMMA mma, + float const* global_amax, + const size_t* rng_state) +{ + using namespace cute; + using X = Underscore; + // static constexpr bool kApplyStochasticRounding = true; + using ElementAccumulator = float; + static constexpr int K_PIPE_MAX = size<3>(ASmemLayout{}); + using AtomThrShapeMNK = Shape(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>; + static constexpr uint32_t kTmaTransactionBytes = + cutlass::bits_to_bytes(size(AtomThrShapeMNK{}) * cosize(take<0,3>(ASmemLayout{})) * cute::sizeof_bits_v); + + static constexpr int kTmaRhtTensorTransactionBytes = + cutlass::bits_to_bytes(16 * 16 * cute::sizeof_bits_v); + static constexpr int AccumulatorPipelineStageCount = 16; + + static constexpr int MainloopPipelineStageCount = size<3>(ASmemLayout{}); + using MainloopPipeline = cutlass::PipelineTmaUmmaAsync< + MainloopPipelineStageCount, + Shape<_1,_1,_1>, + AtomThrShapeMNK>; + using MainloopPipelineState = typename MainloopPipeline::PipelineState; + + using TmemAllocator = cute::TMEM::Allocator1Sm; 
+ static constexpr int VectorSize = 16; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + // Preconditions + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + CUTE_STATIC_ASSERT(is_static::value); + + // Represent the full tensors + Tensor mA = tma_load_a.get_tma_tensor(make_shape(M,N)); + Tensor mB = tma_load_b.get_tma_tensor(make_shape(16,16)); + Tensor mC = make_tensor(cute::subbyte_iterator(C), make_shape(M,N), dC); // (M,N) + + auto sfc_shape = make_shape( + M, + make_shape( make_shape(Int<16>{}, _4{}), N / 64 ) + ); + + auto sfc_stride = make_stride( + N / 16, + make_stride( make_stride(_0{}, _1{}), _4{} ) + ); + + auto sfc_layout = make_layout(sfc_shape, sfc_stride); + Tensor mSFC = make_tensor(make_gmem_ptr(SFC), sfc_layout); + + auto cluster_shape = Shape< _1, _1, _1>{}; + + // Get the appropriate blocks for this Cluster + dim3 cluster_coord_in_grid = cluster_id_in_grid(); + + // Total number of k-tiles + const int K_TILE_MAX = min(N, K) / 64; + uint32_t tiles_in_m = (M + size<0>(cluster_tile) - 1) / size<0>(cluster_tile); + uint32_t tiles_in_n = (N + 64 - 1) / 64; + uint32_t linear_tile_idx = blockIdx.x; + uint32_t tile_idx_m = linear_tile_idx % tiles_in_m; + uint32_t tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + + + auto mainloop_tiler = Shape<_128,_16,_64>{}; + auto epilogue_tiler = Shape<_128,_64,_64>{}; + Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_,_, _), Step<_1, X,_1>{}); + Tensor gB_nk = local_tile(mB, cluster_tile, make_coord(_,_, _), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC_mn = local_tile(mC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + Tensor gSFC_mn = local_tile(mSFC, epilogue_tiler, make_coord(_,_, _), Step<_1,_1, X>{}); // (BLK_M,BLK_N) + // Allocate SMEM + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& 
shared_storage = *reinterpret_cast(shared_memory); + Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()), sAlayout); // (MMA,MMA_M,MMA_N,PIPE) + Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()), sBlayout); // (MMA,MMA_N,MMA_K,PIPE) + + + // + // MMA: Define C accumulators and A/B partitioning + // + + int block_rank_in_cluster = cute::block_rank_in_cluster(); + ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx + Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k) + + auto mma_epilogue = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + ThrMMA thr_mma_epilogue = mma_epilogue.get_slice(block_rank_in_cluster); + + + using TiledMmaEpilogue = decltype(mma_epilogue); + Tensor tCgA = thr_mma.partition_A(gA_mk); + // Allocate "fragments" -- these are actually umma smem descriptors + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE) + + auto acc_shape_mma = partition_shape_C(TiledMMA{}, take<0,2>(ClusterTileShape{})); + auto acc_shape_epilogue = partition_shape_C(TiledMmaEpilogue{}, take<0,2>(epilogue_tiler)); + + auto bulk_tmem_mma = TiledMMA::make_fragment_C(append(acc_shape_mma, + Int{})); + + auto bulk_tmem_epilogue = TiledMmaEpilogue::make_fragment_C(append(acc_shape_epilogue, + Int{})); + + TmemAllocator tmem_allocator{}; + cutlass::arch::NamedBarrier tmem_allocation_result_barrier(32 + 128, cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier); + + Layout cta_layout_mnk = make_layout(cluster_shape); + Layout cta_layout_vmnk = tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{})); + auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster); + + auto [tAgA, tAsA] = tma_partition(tma_load_a, + get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)), + group_modes<0,3>(tCsA), group_modes<0,3>(tCgA)); + + auto [tBgB, tBsB] = 
tma_partition(tma_load_b, + get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)), + group_modes<0,3>(tCsB), group_modes<0,3>(tCgB)); + + uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk); + uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk); + + int warp_idx = cutlass::canonical_warp_idx_sync(); + + bool is_mma_warp = (warp_idx == 0); + bool is_dma_warp = (warp_idx == 1); + bool is_epilogue_warp = (warp_idx >= 4 && warp_idx <= 7); + + if (is_epilogue_warp && elect_one_sync()) { + cute::prefetch(raw_pointer_cast(global_amax)); + } + + typename MainloopPipeline::Params mainloop_pipeline_params; + if (is_dma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (is_mma_warp) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp; + mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes; + mainloop_pipeline_params.initializing_warp = 0; + MainloopPipeline mainloop_pipeline(shared_storage.mainloop, + mainloop_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + MainloopPipelineState mainloop_pipe_consumer_state; + MainloopPipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + + + + using AccumulatorPipeline = cutlass::PipelineUmmaAsync; + using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState; + + AccumulatorPipelineState accumulator_pipe_consumer_state; + AccumulatorPipelineState accumulator_pipe_producer_state = cutlass::make_producer_start_state(); + + typename AccumulatorPipeline::Params accumulator_pipeline_params; + if (is_mma_warp) { + accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer; + } + if (is_epilogue_warp) { + accumulator_pipeline_params.role = 
AccumulatorPipeline::ThreadCategory::Consumer; + } + // Only one producer thread arrives on this barrier. + accumulator_pipeline_params.producer_arv_count = 1; + accumulator_pipeline_params.consumer_arv_count = size(AtomThrShapeMNK{}) * 128; + accumulator_pipeline_params.initializing_warp = 1; + AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, + accumulator_pipeline_params, + cluster_shape, + cute::true_type{}, // Perform barrier init + cute::true_type{}); // Delay mask calculation + + if (warp_idx == 2 && elect_one_sync()) { + cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1); + } + __syncthreads(); + using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x; + + if (is_dma_warp) { + if (elect_one_sync()) { + cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0], kTmaRhtTensorTransactionBytes); + copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_,0,0), tBsB(_,0)); + } + cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/); + do { + bool is_first_wave = linear_tile_idx == blockIdx.x; + uint32_t skip_wait = is_first_wave; + auto tAgA_mk = tAgA(_,tile_idx_m,_); + int k_tile = 0; + auto barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + + + CUTE_NO_UNROLL + while (k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n) { + int k_tile_idx_n = tile_idx_n + k_tile; + ++k_tile; + skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount); + mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state, barrier_token); + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state); + int write_stage = mainloop_pipe_producer_state.index(); + ++mainloop_pipe_producer_state; + barrier_token = mainloop_pipeline.producer_try_acquire(mainloop_pipe_producer_state, skip_wait); + if 
(cute::elect_one_sync()) { + copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_,k_tile_idx_n), tAsA(_,write_stage)); + } + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + mainloop_pipeline.producer_tail(mainloop_pipe_producer_state); + } else if (is_mma_warp) { + mma.accumulate_ = UMMA::ScaleOut::Zero; + + tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns, &shared_storage.tmem_base_ptr); + __syncwarp(); + tmem_allocation_result_barrier.arrive(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_mma.data() = tmem_base_ptr; + + do { + uint32_t skip_wait = K_TILE_MAX <= 0; + auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ) + { + mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token); + int read_stage = mainloop_pipe_consumer_state.index(); + auto tCrA_mk = tCrA(_,_,_,read_stage); + auto tCrB_nk = tCrB(_,_,0,0); + CUTE_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) / 4; ++k_block) + { + accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state); + CUTE_UNROLL + for (int i = 0; i < 4; i++) { + auto accumulators = bulk_tmem_mma(_,_,_,accumulator_pipe_producer_state.index() * 4 + i); + gemm(mma, tCrA_mk(_,_,k_block * 4 + i), tCrB_nk, accumulators); + } + + accumulator_pipeline.producer_commit(accumulator_pipe_producer_state); + ++accumulator_pipe_producer_state; + } + auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state; + ++mainloop_pipe_consumer_state; + ++k_tile; + skip_wait = k_tile >= K_TILE_MAX; + barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait); + mainloop_pipeline.consumer_release(curr_mainloop_pipe_consumer_state); + 
} + + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + tmem_allocator.release_allocation_lock(); + accumulator_pipeline.producer_tail(accumulator_pipe_producer_state); + tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns); + } else if (is_epilogue_warp) { + const float global_amax_val = *global_amax; + static constexpr int FragmentSize = 256 / sizeof_bits_v; + + tmem_allocation_result_barrier.arrive_and_wait(); + uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr; + bulk_tmem_epilogue.data() = tmem_base_ptr; + int thread_idx = threadIdx.x % 128; + + Tensor tCgC = thr_mma_epilogue.partition_C(gC_mn); // (MMA,MMA_M,MMA_N) // (MMA,MMA_M,MMA_N) + auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_,_,_,_0{})); + auto tiled_r2g = make_tiled_copy_D(Copy_Atom{}, tiled_t2r); + auto thr_t2r = tiled_t2r.get_slice(thread_idx); + auto thr_r2g = tiled_r2g.get_slice(thread_idx); + + // NVFP4 non-E8 recipe constants and global scales + static constexpr float fp4_max = 6.0f; + + const float global_encode_scale = ComputeGlobalEncodeScaleFP4(global_amax_val); + const float global_decode_scale = 1.0f / global_encode_scale; + auto sfd_converter = cutlass::NumericConverter{}; + + do { + for (int k_tile = 0; k_tile < K_TILE_MAX && k_tile + tile_idx_n < tiles_in_n; ++k_tile) { + Tensor tCgC_mn = tCgC(_,_,_,tile_idx_m,tile_idx_n+k_tile); + + Tensor tCgSFC_mn = gSFC_mn(_,_,tile_idx_m,tile_idx_n+k_tile); + accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state); + + auto tCtC = bulk_tmem_epilogue(_,_,_,accumulator_pipe_consumer_state.index()); + Tensor tDtC = thr_t2r.partition_S(tCtC); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDgC = thr_t2r.partition_D(tCgC_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + + Tensor tTR_rAcc = make_tensor(shape(tDgC)); // 
((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N) + Tensor tDrC = make_tensor(shape(tDgC)); + Tensor tTR_rAcc_frag = recast>(coalesce(tTR_rAcc)); + Tensor tDrC_frag = recast>(coalesce(tDrC)); + + Tensor src = thr_r2g.retile_S(tDrC); + Tensor dst = thr_r2g.retile_D(tDgC); + + Tensor tCgSFC = make_tensor(tCgSFC_mn.data(), make_layout( + make_shape(shape(tCgSFC_mn), Int<1>{}, Int<1>{}), + make_stride(stride(tCgSFC_mn), Int<0>{}, Int<0>{}) + )); + + Tensor tDgSFC = filter(thr_t2r.partition_D(tCgSFC)); + Tensor tDrSFC = make_tensor(shape(tDgSFC)); + + static constexpr int NumVecs = size(tDgC) / VectorSize; + Tensor tC_rRowSFD_frg = recast>(tDrSFC); + + cutlass::maximum_absolute_value_reduction, true> amax_reduction; + cutlass::Array vec_maxs; + cutlass::Array pvscales; + // TMEM_LOAD + copy(tiled_t2r, tDtC, tTR_rAcc); + cutlass::arch::fence_view_async_tmem_load(); + + accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state); + + ++accumulator_pipe_consumer_state; + + // Cast data from FP32 to BF16 to FP32. 
+ auto convert_accum_to_bf16 = cutlass::NumericArrayConverter{}; + auto convert_bf16_to_accum = cutlass::NumericArrayConverter{}; + tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{}))); + + auto compute_frgs = reinterpret_cast *>(tTR_rAcc_frag.data()); + auto output_frgs = reinterpret_cast *>(tDrC_frag.data()); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]); + } + + pvscales = cutlass::divides>{}(vec_maxs, fp4_max); + pvscales = cutlass::multiplies>{}(pvscales, global_encode_scale); + auto pvscales_cvted = cutlass::NumericArrayConverter{}(pvscales); + + tC_rRowSFD_frg(_0{}) = pvscales_cvted; + auto qpvscale_ups = cutlass::NumericArrayConverter{}(tC_rRowSFD_frg(_0{})); + auto qpvscale_scaled = cutlass::multiplies>{}(qpvscale_ups, global_decode_scale); + auto acc_scales = cutlass::divides>{}(1.0, qpvscale_scaled); + + // Initialize RNG for tile + const size_t rng_sequence + = thread_idx + k_tile * 256 + linear_tile_idx * K_TILE_MAX * 256; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = uint4{0, 0, 0, 0}; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < NumVecs; v++) { + auto acc_scale = cutlass::minimum_with_nan_propagation{}(acc_scales[v], cutlass::platform::numeric_limits::max()); + // auto acc_scale = acc_scales[v]; + if constexpr (kEnableStochasticRounding) { + random_uint4 = rng.generate4(); + output_frgs[v] = StochasticNumericConverter( + cutlass::multiplies>{}( + compute_frgs[v], + acc_scale + ), + reinterpret_cast*>(&random_uint4)); + } else { + output_frgs[v] = cutlass::NumericArrayConverter{}(cutlass::multiplies>{}(compute_frgs[v], acc_scale)); + } + } + + copy(tiled_r2g, src, dst); + + copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFC, tDgSFC); + + } + linear_tile_idx += gridDim.x; + tile_idx_m = linear_tile_idx % tiles_in_m; + 
tile_idx_n = (linear_tile_idx / tiles_in_m) * K_TILE_MAX; + } while (tile_idx_m < tiles_in_m && tile_idx_n < tiles_in_n); + } +} + +// this function computes RHT-GEMM for +// A: m x n: col-major +// B: 16 x 16: row-major +// C: m x n: row-major +// SFC: m x (n/16): row-major +template +void +rht_gemm_ntt_w_sfc(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 2048) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = static_cast(m); + auto N = static_cast(n); + + // Define strides (mixed) + auto dA = make_stride(Int<1>{}, m); // (dM,dK) + auto dB = make_stride(Int<1>{}, 16); // (dN,dK) + auto dC = make_stride(n, Int<1>{}); // (dM,dN) + + auto cga_shape = Shape< _1, _1, _1>{}; + auto cga_tile_shape = Shape<_128,_16,_16>{}; + auto cluster_tile_mainloop = Shape<_128,_16,_64>{}; + + // Construct the MMA + auto mma = make_tiled_mma(SM100_MMA_F16BF16_SS{}, + Layout>{}); + + // MMA in CGA Layout XXX: Need to generalize synchro? {$nv-release-never} + + // Assert that the TiledMMA uses all CTAs in the CGA. 
+ CUTE_STATIC_ASSERT_V(size(cga_shape) == size(mma)); + CUTE_STATIC_ASSERT_V(evenly_divides(cga_tile_shape, tile_shape(mma))); + + // Determine the A and B shapes + auto mma_shape_B = partition_shape_B(mma, make_shape(size<1>(cga_tile_shape), size<2>(cga_tile_shape))); + + using TiledMma = decltype(mma); + using AtomThrID = typename TiledMma::AtomThrID; + + using SmemShape_M = decltype(shape_div(shape<0>(cga_tile_shape), shape_div(shape<0>(cga_tile_shape), size<0>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_N = decltype(shape_div(shape<1>(cga_tile_shape), shape_div(shape<1>(cga_tile_shape), size<1>(cga_tile_shape) / size(AtomThrID{})))); + using SmemShape_K = decltype(cute::get<2>(cga_tile_shape)); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TB, SmemShape_N, SmemShape_K>()); + + auto mma_shape_A = partition_shape_A(mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop))); + using SmemShape_M_A = decltype(shape_div(shape<0>(cluster_tile_mainloop), shape_div(shape<0>(cluster_tile_mainloop), size<0>(cluster_tile_mainloop) / size(AtomThrID{})))); + using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop)); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector< + cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>()); + + // Define the smem layouts (static) + // Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory + constexpr int kBlackwellSmemSize = 232448; // 232KB in bytes + constexpr int kBytesPerStage = cute::size(mma_shape_A) * sizeof(TA) + cute::size(mma_shape_B) * sizeof(TB); + constexpr int kReservedBytes = 256; // Reserve for barriers and other uses + constexpr int kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage; + auto sP = Int{}; // SMEM pipelines + auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP)); // (MMA,MMA_M,MMA_K,PIPE) + 
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{}, append(mma_shape_B, sP)); // (MMA,MMA_N,MMA_K,PIPE) + auto sC = Layout<_1>{}; // XXX Dummy + + // Create GMEM tensors + Tensor tensorA = make_tensor(A, make_layout(make_shape(M,N), dA)); // (M,N) + Tensor tensorB = make_tensor(B, make_layout(make_shape(16,16), dB)); // (16,16) + + // Create the TiledCopy + + auto tma_load_a = make_tma_copy_A_sm100( + SM90_TMA_LOAD{}, + tensorA, + sA(_,_,_,0), + cluster_tile_mainloop, + mma); + auto tma_load_b = make_tma_copy_B_sm100( + SM90_TMA_LOAD{}, + tensorB, + sB(_,_,_,0), + cga_tile_shape, + mma); + + // Assert checks on tile sizes -- no predication + NVTE_CHECK(M % size<0>(cga_tile_shape) == 0, + "Inner dimension must be divisible by ", static_cast(size<0>(cga_tile_shape)), " but got ", M, "."); + NVTE_CHECK(N % (4 * size<1>(cga_tile_shape)) == 0, + "Outer dimension must be divisible by ", 4 * static_cast(size<1>(cga_tile_shape)), + " but got ", N, "."); + + uint32_t tiles = size(ceil_div(M, get<0>(cga_tile_shape))) * size(ceil_div(N, k_tile_size)); + + tiles = (tiles < sm_count) ? tiles : sm_count; + + dim3 dimBlock(256); + dim3 dimCluster(size<0>(cga_shape), size<1>(cga_shape), size<2>(cga_shape)); + dim3 dimGrid(tiles, 1, 1); + + int smem_size = sizeof(SharedStorage); + auto* kernel_ptr = &rht_gemm_device< + decltype(M), decltype(N), decltype(k_tile_size), decltype(cga_tile_shape), + TA, decltype(dA), decltype(sA), decltype(tma_load_a), + TB, decltype(dB), decltype(sB), decltype(tma_load_b), + TC, decltype(dC), decltype(sC), + TSFC, + decltype(mma), + kEnableStochasticRounding>; + + bool status = cudaFuncSetAttribute(*kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (status != cudaSuccess) { + std::cerr << "Error: Failed to set Shared Memory size." 
<< std::endl; + return; + } + (*kernel_ptr) + <<< dimGrid, dimBlock, smem_size, stream >>> + (M, N, k_tile_size, cga_tile_shape, + A, dA, sA, tma_load_a, + B, dB, sB, tma_load_b, + C, dC, sC, + SFC, + mma, global_amax, + rng_state); +} + +// this function is used to wrap the rht_gemm_ntt_w_sfc function +//to transpose the input tensor A +template +void +rht_gemm_ttt_wrapper(int m, int n, + TA const* A, + TB const* B, + TC * C, + TSFC * SFC, + float const* global_amax, + const size_t* rng_state, + uint32_t sm_count, + cudaStream_t stream, + int k_tile_size = 1024) +{ + // in addition to transpose the input tensor A + // we also need to reshape m, n to at best + // ultilize as many SMs as possible while keeping + // a relatively large contiguous dimension. + // for example, after swapping m, n for transpose purposes, + // the input / output tensor shapes for RHT-GEMM are: + // A: n x m: col-major + // B: 16 x 16: row-major + // C: n x m: row-major + // SFC: n x (m/16): row-major + rht_gemm_ntt_w_sfc( + n, m, + A, B, C, + SFC, global_amax, + rng_state, + sm_count, stream, + k_tile_size); +} + +} // namespace +} // namespace detail + +// clang-format on + +void hadamard_transform_cast_fusion_columnwise(const Tensor &input_, Tensor &output_, + const Tensor &hadamard_matrix_, + QuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(hadamard_transform_cast_fusion_columnwise); + + // Check input and output tensors + NVTE_CHECK(input_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be BF16 tensor, but scaling mode is ", + to_string(input_.scaling_mode), "."); + NVTE_CHECK(input_.dtype() == transformer_engine::DType::kBFloat16, + "Input tensor must be BF16 tensor, but dtype is ", to_string(input_.dtype()), "."); + NVTE_CHECK(input_.dim() >= 2, "Input must be a 2D tensor."); + const SimpleTensor &input = input_.data; + SimpleTensor &global_amax = output_.amax; + SimpleTensor &output_t = output_.data; + SimpleTensor &scale_inv_t = 
output_.scale_inv; + + // Stochastic rounding config + const bool use_stochastic_rounding = quant_config.stochastic_rounding; + const size_t *rng_state = nullptr; + if (quant_config.rng_state != nullptr) { + Tensor &rng_state_tensor = *convertNVTETensor(quant_config.rng_state); + NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_tensor.data.dptr); + } + + // Template arguments + using TA = cute::bfloat16_t; + using TB = cute::bfloat16_t; + using TC = cutlass::float_e2m1_t; + using TSFC = cutlass::float_ue4m3_t; + + checkCuDriverContext(stream); + + // Check Hadamard matrix + constexpr int kHadamardDimension = 16; + NVTE_CHECK(hadamard_matrix_.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Hadamard matrix must be BF16 tensor, but scaling mode is ", + to_string(hadamard_matrix_.scaling_mode), "."); + NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16, + "Hadamard matrix must be BF16 tensor, but dtype is ", + to_string(hadamard_matrix_.dtype()), "."); + const SimpleTensor &hadamard_matrix = hadamard_matrix_.data; + NVTE_CHECK( + (hadamard_matrix_.shape() == std::vector{kHadamardDimension, kHadamardDimension}), + "Hadamard matrix must have shape=", + std::vector{kHadamardDimension, kHadamardDimension}, + ", but got shape=", hadamard_matrix_.shape(), "."); + const size_t hadamard_dimension = hadamard_matrix.shape[0]; + + const size_t ndim = input.shape.size(); + const size_t n = input.shape[ndim - 1]; + size_t m = 1; + for (size_t i = 0; i < ndim - 1; ++i) { + m *= input.shape[i]; + } + + auto sm_count = transformer_engine::cuda::sm_count(); + + NVTE_CHECK(n % hadamard_dimension == 0, "row_length must be divisible by hadamard_dimension."); + + NVTE_CHECK(m % hadamard_dimension == 0, "num_rows must be divisible by 
hadamard_dimension"); + + int k_tile_size = 1024; + + if (m == 8192 && n == 5120) { + k_tile_size = 512; + } else if (m == 8192 && n == 10240) { + k_tile_size = 1024; + } else if (m == 8192 && n == 2560) { + k_tile_size = 1280; + } else if (m == 8192 && n == 11328) { + k_tile_size = 1024; + } else if (m == 8192 && n == 512) { + k_tile_size = 256; + } else if (m == 8192 && n == 3584) { + k_tile_size = 512; + } else if (m == 11328 && n == 8192) { + k_tile_size = 1024; + } else if (m == 5120 && n == 8192) { + k_tile_size = 512; + } else if (m == 10240 && n == 8192) { + k_tile_size = 1024; + } else if (m == 2560 && n == 8192) { + k_tile_size = 1280; + } else if (m == 512 && n == 8192) { + k_tile_size = 256; + } else if (m == 3584 && n == 8192) { + k_tile_size = 512; + } else if (m < 1024 || n < 1024) { + k_tile_size = 512; + } + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kUseStochasticRounding, + detail::rht_gemm_ttt_wrapper( + /*m=*/m, + /*n=*/n, + /*A=*/reinterpret_cast(input.dptr), + /*B=*/reinterpret_cast(hadamard_matrix.dptr), + /*C=*/reinterpret_cast(output_t.dptr), + /*SFC=*/reinterpret_cast(scale_inv_t.dptr), + /*global_amax=*/reinterpret_cast(global_amax.dptr), + /*rng_state=*/rng_state, + /*sm_count=*/sm_count, + /*stream=*/stream, + /*k_tile_size=*/k_tile_size);); +} + +} // namespace transformer_engine + +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { + NVTE_API_CALL(nvte_hadamard_transform_cast_fusion_columnwise); + using namespace transformer_engine; + QuantizationConfig quant_config_cpp; + if (quant_config != nullptr) { + quant_config_cpp = *reinterpret_cast(quant_config); + } + hadamard_transform_cast_fusion_columnwise( + *convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp, stream); +} diff --git 
a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 49029ed58..4e4808858 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -39,6 +39,7 @@ enum class NVTE_Activation_Type { QGEGLU, SRELU, SREGLU, + CLAMPED_SWIGLU }; /*! \brief Computes the GeLU activation of the input. @@ -173,6 +174,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); */ void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gated Swish activation of the input used in GPT OSS. + * + * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This Gated activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes Act(input[N, :H]) x input[N, H:] + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha, + cudaStream_t stream); + /*! \brief Computes the gated ReLU activation of the input. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. @@ -230,6 +251,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the gradient of gated Swish activation of the input used in GPT OSS. + * + * https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250 + * This activation has two differences compared to the original SwiGLU + * 1. Both gate and pre-activations are clipped based on parameter limit. + * 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation inspired + * by original GELU paper https://arxiv.org/pdf/1606.08415 + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] limit Clipping limits for gate and pre-activation. + * \param[in] alpha Scaling factor for the sigmoid function used in the activation. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + float limit, float alpha, cudaStream_t stream); + /*! \brief Computes the gated ReLU activation gradient. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. 
diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 4d65e26ce..cffc411a0 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -67,6 +67,11 @@ class CommOverlapCore { std::vector _stream_compute; cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_comm, _comm_launch_event; + private: + void initialize(int tp_size, int num_splits, int num_max_streams, int comm_cga_size, + int gemm_priority, int comm_priority, int num_comm_sm, bool set_sm_margin, + bool use_ce, bool atomic_gemm); + public: CommOverlapCore() {} // dummy constructor for exposing type to Python @@ -78,17 +83,26 @@ class CommOverlapCore { virtual ~CommOverlapCore(); + void *get_ubuf_dptr() { return _ubuf.dptr(); } + void set_ubuf_scale_inv(float *scale_inv) { _ubuf_scale_inv = scale_inv; _ubuf_scale_inv_initialized = true; } + virtual void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) { + NVTE_ERROR("Operation is not implemented."); + } + TensorWrapper get_tensor_chunk(const TensorWrapper &source, size_t offset, const std::vector &shape); TensorWrapper get_buffer_chunk_like(const TensorWrapper &source, size_t offset, const std::vector &shape); + int get_tp_size() { return _tp_size; } + bool is_atomic_gemm() { return _atomic_gemm; } bool is_p2p_overlap() { return _is_p2p; } @@ -148,6 +162,10 @@ class CommOverlapBase : public CommOverlapCore { cudaStream_t _stream_comm; cudaEvent_t _start_d2dcopy; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + bool rs_overlap_first_gemm); + public: CommOverlapBase() {} // dummy constructor for exposing type to Python @@ -224,6 +242,10 @@ class CommOverlapP2PBase : public CommOverlapCore { cudaStream_t _stream_recv; cudaEvent_t _stop_send, 
_stop_recv; + private: + void initialize(const std::vector &buffer_shape, DType buffer_dtype, + CommOverlapType comm_type, bool aggregate); + public: CommOverlapP2PBase() {} // dummy constructor for exposing type to Python @@ -237,6 +259,9 @@ class CommOverlapP2PBase : public CommOverlapCore { virtual ~CommOverlapP2PBase(); + void copy_into_buffer(cudaStream_t stream, const TensorWrapper &source, bool local_chunk, + bool rowwise = true) override; + TensorWrapper get_buffer_chunk_by_id(const TensorWrapper &source, size_t buffer_id); void bulk_overlap(const TensorWrapper &A, bool transa, const TensorWrapper &B, bool transb, diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 851032e04..158d8ea5d 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -126,6 +126,24 @@ enum NVTE_Mask_Type { NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5, }; +/*! \enum NVTE_Softmax_Type + * \brief Attention softmax types as described in + * Efficient Streaming Language Models with Attention Sinks (https://arxiv.org/pdf/2309.17453v3). + * For a given attention score S = Q*K^T, different softmax types perform different operations on S, + * NVTE_VANILLA_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + * NVTE_OFF_BY_ONE_SOFTMAX: S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + * NVTE_LEARNABLE_SOFTMAX: S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + * where alpha is a learnable parameter in shape [H]. + */ +enum NVTE_Softmax_Type { + /*! Vanilla softmax */ + NVTE_VANILLA_SOFTMAX = 0, + /*! Off-by-one softmax */ + NVTE_OFF_BY_ONE_SOFTMAX = 1, + /*! Learnable softmax */ + NVTE_LEARNABLE_SOFTMAX = 2, +}; + /*! 
\enum NVTE_Fused_Attn_Backend * \brief Fused attention backends */ @@ -185,29 +203,35 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout); /*! \brief Get fused attention backend based on input parameters. * - * \param[in] is_training Whether the model is in training mode. - * \param[in] q_dtype The data type of Tensor Q. - * \param[in] kv_dtype The data type of Tensors K, V. - * \param[in] qkv_layout The layout of Tensors Q, K, V. - * \param[in] bias_type The attention bias type. - * \param[in] attn_mask_type The attention mask type. - * \param[in] dropout The dropout probability. - * \param[in] num_attn_heads The number of heads in Q. - * \param[in] num_gqa_groups The number of heads in K, V. - * \param[in] max_seqlen_q The sequence length of Q. - * \param[in] max_seqlen_kv The sequence length of K, V. - * \param[in] head_dim_qk The head dimension of Q, K. - * \param[in] head_dim_v The head dimension of V. - * \param[in] window_size_left Sliding window size (the left half). - * \param[in] window_size_right Sliding window size (the right half). + * \param[in] is_training Whether the model is in training mode. + * \param[in] q_dtype The data type of Tensor Q. + * \param[in] kv_dtype The data type of Tensors K, V. + * \param[in] qkv_layout The layout of Tensors Q, K, V. + * \param[in] bias_type The attention bias type. + * \param[in] attn_mask_type The attention mask type. + * \param[in] softmax_type The attention softmax type. + * \param[in] dropout The dropout probability. + * \param[in] num_attn_heads The number of heads in Q. + * \param[in] num_gqa_groups The number of heads in K, V. + * \param[in] max_seqlen_q The sequence length of Q. + * \param[in] max_seqlen_kv The sequence length of K, V. + * \param[in] head_dim_qk The head dimension of Q, K. + * \param[in] head_dim_v The head dimension of V. + * \param[in] window_size_left Sliding window size (the left half). + * \param[in] window_size_right Sliding window size (the right half). 
+ * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. */ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( bool is_training, NVTEDType q_dtype, NVTEDType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph); /*! \brief Compute dot product attention with packed QKV input. + * + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -243,6 +267,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * * \param[in] QKV The QKV tensor in packed format, H3D or 3HD. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -253,27 +278,36 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in] max_seqlen Max sequence length used for computing, * it may be >= max(seqlen_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. 
* \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, NVTETensor S, - NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, - const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, - const NVTETensor rng_state, size_t max_seqlen, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, +[[deprecated( + "nvte_fused_attn_fwd_qkvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] +void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, + NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens, + const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, + size_t max_seqlen, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, + NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with packed QKV input. * + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * Support Matrix for ROCm AOTriton: \verbatim | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | @@ -309,6 +343,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * e.g. 
M, ZInv, rng_state. * \param[out] dQKV The gradient of the QKV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens Cumulative sequence lengths, [batch_size + 1]. * \param[in] cu_seqlens_padded Cumulative sequence offsets for QKV, [batch_size + 1]. * \param[in] max_seqlen Max sequence length used for computing, @@ -318,23 +353,29 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias, * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, - const NVTETensor S, NVTETensor dP, - const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, - NVTETensor dBias, const NVTETensor cu_seqlens, - const NVTETensor cu_seqlens_padded, size_t max_seqlen, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, - bool deterministic, NVTETensor workspace, cudaStream_t stream); +[[deprecated( + "nvte_fused_attn_bwd_qkvpacked() is deprecated. 
Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] +void nvte_fused_attn_bwd_qkvpacked( + const NVTETensor QKV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, + NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQKV, NVTETensor dBias, + NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded, + size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph, + NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with packed KV input. + * + * \deprecated Please use `nvte_fused_attn_fwd` with separate Q, K, V tensors instead. * * Computes: * - P = Q * Transpose(K) + Bias @@ -371,6 +412,7 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] Q The Q tensor, in HD layouts. * \param[in] KV The KV tensor, in 2HD or H2D layouts. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -387,28 +429,37 @@ void nvte_fused_attn_bwd_qkvpacked(const NVTETensor QKV, const NVTETensor O, con * \param[in] max_seqlen_kv Max sequence length used for computing for KV. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. + * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. 
* \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_fwd_kvpacked() is deprecated. Please use nvte_fused_attn_fwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_fwd_kvpacked( - const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); + const NVTETensor Q, const NVTETensor KV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, + NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, + const NVTETensor page_table_v, const NVTETensor rng_state, size_t max_seqlen_q, + size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! 
\brief Compute the backward of the dot product attention with packed KV input. * + * \deprecated Please use `nvte_fused_attn_bwd` with separate Q, K, V tensors instead. * Support Matrix for ROCm AOTriton: \verbatim | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | @@ -446,6 +497,7 @@ void nvte_fused_attn_fwd_kvpacked( * \param[out] dQ The gradient of the Q tensor. * \param[out] dKV The gradient of the KV tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for KV, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -459,20 +511,26 @@ void nvte_fused_attn_fwd_kvpacked( * \param[in] qkv_layout QKV tensor's layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ +[[deprecated( + "nvte_fused_attn_bwd_kvpacked() is deprecated. 
Please use nvte_fused_attn_bwd() with separate " + "Q, K, V tensors instead.")]] void nvte_fused_attn_bwd_kvpacked( const NVTETensor Q, const NVTETensor KV, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, - NVTETensor dKV, NVTETensor dBias, const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, - const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, - size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, - NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, bool deterministic, NVTETensor workspace, + NVTETensor dKV, NVTETensor dBias, NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens_q, + const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, + float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, + int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute dot product attention with separate Q, K and V. @@ -518,6 +576,7 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] K The K tensor. * \param[in] V The V tensor. * \param[in] Bias The Bias tensor. + * \param[in] SoftmaxOffset The SoftmaxOffset tensor. * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, @@ -534,27 +593,29 @@ void nvte_fused_attn_bwd_kvpacked( * \param[in] max_seqlen_kv Max sequence length used for computing for K and V. * it may be >= max(seqlen_kv_i) for i=0,...batch_size-1. * \param[in] is_training Whether this is in training mode or inference. 
+ * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] attn_scale Scaling factor for Q * K.T. * \param[in] dropout Dropout probability. * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. */ -void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, - const NVTETensor Bias, NVTETensor S, NVTETensor O, - NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, - const NVTETensor cu_seqlens_kv_padded, const NVTETensor page_table_k, - const NVTETensor page_table_v, const NVTETensor rng_state, - size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, - float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, - cudaStream_t stream); +void nvte_fused_attn_fwd( + const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor Bias, + const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, + const NVTETensor page_table_k, const NVTETensor page_table_v, const NVTETensor rng_state, + size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, bool return_max_logit, + bool cuda_graph, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, + NVTE_Bias_Type bias_type, 
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, NVTETensor workspace, cudaStream_t stream); /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * @@ -602,6 +663,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[out] dK The gradient of the K tensor. * \param[out] dV The gradient of the V tensor. * \param[out] dBias The gradient of the Bias tensor. + * \param[out] dSoftmaxOffset The gradient of the SoftmaxOffset tensor. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -615,23 +677,26 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] qkv_layout QKV tensors' layout. * \param[in] bias_type Bias type. * \param[in] attn_mask_type Attention mask type. + * \param[in] softmax_type Attention softmax type. * \param[in] window_size_left Sliding window size (the left half). * \param[in] window_size_right Sliding window size (the right half). * \param[in] deterministic Whether to execute with deterministic behaviours. + * \param[in] cuda_graph Whether cuda graph capture is enabled or not. * \param[in] workspace Workspace tensor. * \param[in] stream CUDA stream used for this operation. 
*/ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETensor V, const NVTETensor O, const NVTETensor dO, const NVTETensor S, NVTETensor dP, const NVTETensorPack *Aux_CTX_Tensors, NVTETensor dQ, NVTETensor dK, - NVTETensor dV, NVTETensor dBias, const NVTETensor cu_seqlens_q, - const NVTETensor cu_seqlens_kv, const NVTETensor cu_seqlens_q_padded, + NVTETensor dV, NVTETensor dBias, NVTETensor dSoftmaxOffset, + const NVTETensor cu_seqlens_q, const NVTETensor cu_seqlens_kv, + const NVTETensor cu_seqlens_q_padded, const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, int64_t window_size_left, - int64_t window_size_right, bool deterministic, NVTETensor workspace, - cudaStream_t stream); + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + int64_t window_size_left, int64_t window_size_right, bool deterministic, + bool cuda_graph, NVTETensor workspace, cudaStream_t stream); /*! \brief Update the RNG state with the seed and calculated offset. * diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 610868f93..19047f463 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] qkv_format QKV format. * \param[in] interleaved Whether to use interleaved rotary position embedding. 
@@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream); + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream); /*! \brief Apply rotary positional embedding to the combined QKV input tensor. * diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 58c0a1f96..dd312726a 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -17,9 +17,76 @@ #ifdef __cplusplus extern "C" { -#endif +#endif // __cplusplus -/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. +/*! \brief Configuration for matrix multiplication. */ +typedef void *NVTEMatmulConfig; + +/*! \enum NVTEMatmulConfigAttribute + * \brief Type of option for matrix multiplication. + */ +enum NVTEMatmulConfigAttribute { + /*! Bias tensor + * + * If provided, the bias tensor is applied in the GEMM epilogue. + */ + kNVTEMatmulConfigBiasTensor = 0, + /*! Bias gradient tensor + * + * If provided, the bias gradient tensor will be filled in the GEMM epilogue. 
+ */ + kNVTEMatmulConfigDBiasTensor = 1, + /*! Whether to compute GELU in GEMM epilogue. */ + kNVTEMatmulConfigWithGELUEpilogue = 2, + /*! Whether to compute GELU backward in GEMM epilogue. */ + kNVTEMatmulConfigWithDGELUEpilogue = 3, + /*! Auxiliary tensor for GEMM epilogue. + * + * For GELU, this will be filled with the GELU input. For GELU + * backward, this is expected to already be filled with the GELU + * input. + */ + kNVTEMatmulConfigEpilogueAuxTensor = 4, + /*! Whether to use split accumulator for FP8 GEMM. */ + kNVTEMatmulConfigUseSplitAccumulator = 5, + /*! Number of streaming multiprocessors to use in GEMM kernel. */ + kNVTEMatmulConfigSMCount = 6, + kNVTEMatmulConfigNumAttributes +}; + +/*! \brief Create a matrix multiplication configuration. */ +NVTEMatmulConfig nvte_create_matmul_config(); + +/*! \brief Query an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[out] buf Memory address to write option value. Ignored if + * NULL. + * \param[in] size_in_bytes Size of buf. + * \param[out] size_written Number of bytes that have been written to + * buf. If buf is NULL, then the number of + * bytes that would have been written. + */ +void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + void *buf, size_t size_in_bytes, size_t *size_written); + +/*! \brief Set an option in matrix multiplication configuration. + * + * \param[in] config Matrix multiplication configuration. + * \param[in] attr Option type. + * \param[in] buf Memory address to read option value. + * \param[in] size_in_bytes Size of buf. + */ +void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr, + const void *buf, size_t size_in_bytes); + +/*! \brief Destroy a matrix multiplication configuration. */ +void nvte_destroy_matmul_config(NVTEMatmulConfig config); + +/*! 
\brief Compute matrix multiplication of 2 matrices, potentially fused with other operations (deprecated). + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. * * Computes: * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -46,8 +113,31 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations. + * + * Computes: + * - `D = alpha * op(A) * op(B) + beta * C` + * + * \param[in] transa Whether to transpose A matrix. + * \param[in] transb Whether to transpose B matrix. + * \param[in] alpha Scaling factor applied to matmul output. + * \param[in] A A matrix. + * \param[in] B B matrix. + * \param[in] beta Scaling factor applied to C matrix. + * \param[in] C C matrix. + * \param[out] D Output matrix. + * \param[in] workspace Workspace tensor. + * \param[in] config Additional configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm_v2(int transa, int transb, const float *alpha, const NVTETensor A, + const NVTETensor B, const float *beta, const NVTETensor C, NVTETensor D, + NVTETensor workspace, NVTEMatmulConfig config, cudaStream_t stream); + /*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations, - * allowing for using a scaling factor for the GEMM result and the accumulation input + * allowing for using a scaling factor for the GEMM result and the accumulation input (deprecated) + * + * This has been deprecated in favor of nvte_cublas_gemm_v2. 
* * Computes: * - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors @@ -135,14 +225,16 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. */ -void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, - bool transa, bool transb, bool grad, NVTETensor* workspace, +void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); #ifdef __cplusplus } // extern "C" -#endif +#endif // __cplusplus + +#ifdef __cplusplus /*! \namespace transformer_engine */ @@ -157,6 +249,89 @@ namespace transformer_engine { void nvte_cublas_handle_init(); #endif +/*! \struct MatmulConfigWrapper + * \brief C++ wrapper for NVTEMatmulConfig. + */ +class MatmulConfigWrapper { + public: + MatmulConfigWrapper() : config_{nvte_create_matmul_config()} {} + + MatmulConfigWrapper(const MatmulConfigWrapper &) = delete; + MatmulConfigWrapper &operator=(const MatmulConfigWrapper &) = delete; + + MatmulConfigWrapper(MatmulConfigWrapper &&other) : config_{other.config_} { + other.config_ = nullptr; + } + MatmulConfigWrapper &operator=(MatmulConfigWrapper &&other) { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + } + config_ = other.config_; + other.config_ = nullptr; + return *this; + } + + ~MatmulConfigWrapper() { + if (config_ != nullptr) { + nvte_destroy_matmul_config(config_); + config_ = nullptr; + } + } + + /*! \brief Get the underlying NVTEMatmulConfig. + * + * \return NVTEMatmulConfig held by this MatmulConfigWrapper. 
+ */ + operator NVTEMatmulConfig() const noexcept { return config_; } + + /*! \brief Set bias tensor. */ + void set_bias_tensor(NVTETensor bias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigBiasTensor, &bias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set bias gradient tensor. */ + void set_dbias_tensor(NVTETensor dbias_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigDBiasTensor, &dbias_tensor, + sizeof(NVTETensor)); + } + + /*! \brief Set whether to compute GELU in GEMM epilogue. */ + void set_with_gelu_epilogue(bool with_gelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithGELUEpilogue, + &with_gelu_epilogue, sizeof(bool)); + } + + /*! \brief Set whether to compute GELU backward in GEMM epilogue. */ + void set_with_dgelu_epilogue(bool with_dgelu_epilogue) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigWithDGELUEpilogue, + &with_dgelu_epilogue, sizeof(bool)); + } + + /*! \brief Set auxiliary tensor for GEMM epilogue. */ + void set_epilogue_aux_tensor(NVTETensor epilogue_aux_tensor) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigEpilogueAuxTensor, + &epilogue_aux_tensor, sizeof(NVTETensor)); + } + + /*! \brief Set whether to use split accumulator for FP8 GEMM. */ + void set_use_split_accumulator(bool use_split_accumulator) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigUseSplitAccumulator, + &use_split_accumulator, sizeof(bool)); + } + + /*! \brief Set number of streaming multiprocessors to use in GEMM kernel. */ + void set_sm_count(int sm_count) { + nvte_set_matmul_config_attribute(config_, kNVTEMatmulConfigSMCount, &sm_count, sizeof(int)); + } + + private: + /*! \brief Wrapped NVTEMatmulConfig. 
*/ + NVTEMatmulConfig config_ = nullptr; +}; + } // namespace transformer_engine +#endif // __cplusplus + #endif // TRANSFORMER_ENGINE_GEMM_H_ diff --git a/transformer_engine/common/include/transformer_engine/hadamard_transform.h b/transformer_engine/common/include/transformer_engine/hadamard_transform.h new file mode 100644 index 000000000..73edf23a3 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/hadamard_transform.h @@ -0,0 +1,74 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file hadamard_transform.h + * \brief Functions for Hadamard transforms. + */ + +#ifndef TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ +#define TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ + +#ifndef __HIP_PLATFORM_AMD__ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Perform a randomized Hadamard transform on the input tensor. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the absolute maximum reduction on the input tensor with/without + * randomized hadamard transform. The rowwise result is the absolute maximum + * of the input tensor. 
The columnwise result is the absolute maximum of the + * input tensor transposed and applied randomized hadamard transformation. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] random_sign_mask 16-bit sign mask. + * \param[in] random_sign_mask_t 16-bit sign mask. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_hadamard_transform_amax(const NVTETensor input, NVTETensor output, int random_sign_mask, + int random_sign_mask_t, cudaStream_t stream); + +/*! \brief Perform the columnwise hadamard transform cast fusion. + * + * This function is experimental and the API is not stable. + * + * \param[in] input Input tensor to apply Hadamard transform. + * \param[in,out] output Output tensor. + * \param[in] hadamard_matrix Hadamard matrix. + * \param[in] quant_config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. 
+ */ +void nvte_hadamard_transform_cast_fusion_columnwise(const NVTETensor input, NVTETensor output, + const NVTETensor hadamard_matrix, + const NVTEQuantizationConfig quant_config, + cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif + +#endif // TRANSFORMER_ENGINE_HADAMARD_TRANSFORM_H_ diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 89515108a..a5867276f 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -151,6 +151,10 @@ void nvte_fp8_block_scaling_partial_cast(const NVTETensor inp, NVTETensor out, size_t start_offset, size_t block_len, const NVTEDType out_dtype, cudaStream_t stream); +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index 079feb4a7..624e71d1e 100644 --- a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, const size_t num_tensors, cudaStream_t stream); +/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM + * + * \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv. + * \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. 
+ *
+ * This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it is
+ * not natively supported by cublasLt on architectures other than Hopper.
+
+ * Requirements:
+ * - input is an FP8 block scaling tensor
+ * - input has rowwise usage
+ * - input.scale_inv is in GEMM_READY format
+ * - output is an MXFP8 tensor
+ * - output has rowwise usage
+ * - output.scale_inv has appropriate shape
+ * */
+void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
+ cudaStream_t stream);
+
 #ifdef __cplusplus
 } // extern "C"
 #endif
diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h
index 70f90fa76..044e021e6 100644
--- a/transformer_engine/common/include/transformer_engine/transformer_engine.h
+++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h
@@ -74,6 +74,7 @@ enum NVTETensorParam {
 kNVTEAmax = 3, /*!< Amax tensor */
 kNVTERowwiseScaleInv = 4, /*!< Scale inverse tensor for decoding Rowwise Data */
 kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */
+ kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */
 kNVTENumTensorParams
};
@@ -96,10 +97,9 @@ enum NVTEScalingMode {
 */
 NVTE_BLOCK_SCALING_1D = 2,
 NVTE_BLOCK_SCALING_2D = 3,
- /*! Single NVFP4 scale per block of 16 contiguous elements in forward pass (FWD),
- and single MXFP8 scale per block of 32 contiguous elements in backward pass (BWD).
- */
- NVTE_FWD_NVFP4_BWD_MXFP8_SCALING = 4,
+ /*! Single scale per block of 16 elements consecutive in either
+ * rowwise or columnwise direction */
+ NVTE_NVFP4_1D_SCALING = 4,
 NVTE_INVALID_SCALING = 100
};
@@ -338,6 +338,12 @@ enum NVTEQuantizationConfigAttribute {
 * likely be refactored away in the future. */
 kNVTEQuantizationConfigFloat8BlockScaleTensorFormat = 3,
+ /*! 
RNG state (NVTETensor with 2 elements - seed and offset */ + kNVTEQuantizationConfigRNGState = 4, + /*! Whether to use 2D block scaling for NVFP4 */ + kNVTEQuantizationConfigNVFP42DQuantization = 5, + /*! Whether to enable stochastic rounding */ + kNVTEQuantizationConfigStochasticRounding = 6, kNVTEQuantizationConfigNumAttributes }; @@ -449,6 +455,15 @@ inline bool is_fp4_dtype(const DType t) { return t == DType::kFloat4E2M1; } inline bool is_fp4_dtype(const DType t) { return false; } #endif // #ifndef __HIP_PLATFORM_AMD__ +/*! \brief Check if TE datatype is high precision (FP32, FP16, BF16) + * + * Return true if TE datatype is high precision + * \param[in] DType TE Datatype of interest + */ +inline bool is_high_precision_dtype(const DType t) { + return t == DType::kFloat32 || t == DType::kBFloat16 || t == DType::kFloat16; +} + /*! \struct TensorWrapper * \brief C++ wrapper for the NVTETensor class. */ @@ -584,6 +599,11 @@ class TensorWrapper { return set_parameter(kNVTEColumnwiseScaleInv, dptr, type, shape); } + template + TensorWrapper &set_columnwise_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEColumnwiseAmax, dptr, type, shape); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -608,6 +628,10 @@ class TensorWrapper { return get_parameter(kNVTEColumnwiseScaleInv); } + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEColumnwiseAmax); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -856,6 +880,24 @@ class QuantizationConfigWrapper { &format, sizeof(Float8BlockScaleTensorFormat)); } + /*! \brief Set stochastic rounding state */ + void set_rng_state(NVTETensor rng_state) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigRNGState, &rng_state, + sizeof(NVTETensor)); + } + + /*! 
\brief Set whether to use 2D block scaling for NVFP4 */ + void set_nvfp4_2d_quantization(bool nvfp4_2d_quantization) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP42DQuantization, + &nvfp4_2d_quantization, sizeof(bool)); + } + + /*! \brief Set whether to use stochastic rounding */ + void set_stochastic_rounding(bool stochastic_rounding) { + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigStochasticRounding, + &stochastic_rounding, sizeof(bool)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index cfcd91646..a5278522c 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -29,7 +29,7 @@ #ifndef __HIP_PLATFORM_AMD__ #include "../cudnn_utils.h" #else -#include "../util/rocm_cast_kernels.cuh" +#include "../cast/mxfp8/quantize_mxfp8.cuh" #endif #include "../util/system.h" @@ -447,10 +447,10 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) const size_t scale_dim_X_rowwise = 32; const size_t scale_dim_Y_colwise = launch_params.training ? 
32 : 1; - const size_t chunks_Y = DIVUP(rows, transformer_engine::MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, transformer_engine::MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, transformer_engine::MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, transformer_engine::MXFP8_CHUNKS_PER_BLOCK_X); + const size_t chunks_Y = DIVUP(rows, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_Y); + const size_t chunks_X = DIVUP(cols, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNK_DIM_X); + const size_t blocks_Y = DIVUP(chunks_Y, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_Y); + const size_t blocks_X = DIVUP(chunks_X, dispatch::mxfp8::quantize_kernel::MXFP8_CHUNKS_PER_BLOCK_X); const size_t scale_stride_rowwise = launch_params.z_tensor->scale_inv.shape[1]; const size_t scale_stride_colwise = launch_params.training ? launch_params.z_tensor->columnwise_scale_inv.shape[1] : 1; @@ -459,17 +459,18 @@ void rocm_norm_mxfp8_quantize(LaunchParams &launch_params) e8m0_t *const scales_colwise_ptr = launch_params.training ? 
reinterpret_cast(launch_params.z_tensor->columnwise_scale_inv.dptr) : nullptr; - const dim3 block(transformer_engine::MXFP8_THREADS_PER_CHUNK); + const dim3 block(dispatch::mxfp8::quantize_kernel::MXFP8_THREADS_PER_CHUNK); const dim3 grid(blocks_X, blocks_Y); + using namespace dispatch::mxfp8::quantize_kernel; TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( scale_dim_Y_colwise, SCALE_DIM_Y, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( launch_params.z_tensor->dtype(), OType, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(compute_t))), IS_ALIGNED, - cast_mxfp8_2D_kernel<<>>( + quantize_mxfp8_kernel<<>>( reinterpret_cast(launch_params.params.z), nullptr, reinterpret_cast(launch_params.z_tensor->data.dptr), diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 8c6fccfb5..6c21eab7b 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -30,7 +30,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + !is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -66,7 +66,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size NVTE_Norm_Backend norm_backend; bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 
f0cefaded..f5937379c 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -151,7 +151,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( if (requires_amax) { __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(temp_output)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(temp_output); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } } if (params.fp8_out) { temp_output = temp_output * scale; @@ -321,7 +328,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(z_ij)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(z_ij); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } if (params.fp8_out) { z.data.elt[jt] = z_ij * scale; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 6c85cc432..598e0ca08 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -26,7 +26,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens Tensor *rsigma, Tensor *workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream) { if (is_fp8_dtype(z->data.dtype) && !is_delayed_tensor_scaling(z->scaling_mode) && - !is_mxfp_scaling(z->scaling_mode)) { + 
!is_mxfp8_scaling(z->scaling_mode)) { NVTE_ERROR("Not implemented scaling mode: " + to_string(z->scaling_mode) + "."); } @@ -52,7 +52,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens NVTE_Norm_Backend norm_backend; bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ - bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index d057f9beb..9044a4202 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -144,7 +144,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke if (requires_amax) { __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(temp_output)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(temp_output)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(temp_output); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } } if (params.fp8_out) { temp_output = temp_output * scale; @@ -297,7 +304,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(z_ij)); + if (params.fp8_out) { + // For fp8_out, keep amax on pre-scale compute_t + amax = fmaxf(amax, fabsf(z_ij)); + } else { + // Otherwise compute amax on the value converted to output_t (e.g., bf16) + output_t out_t_val = output_t(z_ij); + amax = fmaxf(amax, fabsf(compute_t(out_t_val))); + } if 
(params.fp8_out) { z.data.elt[jt] = z_ij * scale; } diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 466c2e605..223f7a720 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -1,21 +1,27 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """This module provides predefined FP8 recipes.""" from __future__ import annotations -import warnings import os from enum import Enum -from typing import Optional, Union, Callable, NamedTuple -from typing_extensions import Literal +from typing import Any, Literal, Optional, Union, Callable, NamedTuple +from dataclasses import field from pydantic.dataclasses import dataclass -from transformer_engine.common import is_fp8_fnuz +from transformer_engine.common import is_fp8_fnuz, te_rocm_build class _FormatHelper(NamedTuple): + """ + Stores max FP8 values for fprop and bprop a `Format`. + """ + max_fwd: float + max_bwd: float + +class _FormatHelperFP8(NamedTuple): """ Stores max FP8 values for fprop and bprop a `Format`. """ @@ -40,9 +46,12 @@ class _FormatMaxVals(Enum): class Format(Enum): """ Supported FP8 formats. + Supported FP4 formats. 
Values ------ + E2M1 : + All FP4 tensors are in e2m1 format E4M3 : All FP8 tensors are in e4m3 format E5M2 : @@ -51,16 +60,21 @@ class Format(Enum): FP8 tensors in the forward pass are in e4m3 format, FP8 tensors in the backward pass are in e5m2 format """ - E4M3 = _FormatHelper(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) - E5M2 = _FormatHelper(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) - HYBRID = _FormatHelper(fwd=E4M3.fwd, bwd=E5M2.bwd) + E2M1 = _FormatHelper(max_fwd=6, max_bwd=6) + E4M3 = _FormatHelperFP8(fwd=_FormatMaxVals.E4M3.value, bwd=_FormatMaxVals.E4M3.value) + E5M2 = _FormatHelperFP8(fwd=_FormatMaxVals.E5M2.value, bwd=_FormatMaxVals.E5M2.value) + HYBRID = _FormatHelperFP8(fwd=E4M3.fwd, bwd=E5M2.bwd) @dataclass(frozen=True) class MMParams: - """for pytorch as an example, _scaled_mm use_fast_accum = (not use_split_accumulator) - apply split accumulator or not, turning it on will increase accuracy but impact gemm performance, - so only turn it on for certain gemms + """Matrix multiplication options. + + Parameters + ---------- + use_split_accumulator : bool, default = `True` + Use FP8 fast accumulation on Hopper or Ada. For more details, + see CUBLASLT_MATMUL_DESC_FAST_ACCUM option for cublasLtMatmul. """ use_split_accumulator: bool = True @@ -71,10 +85,24 @@ class QParams: """Quantization parameters. 
power_2_scale: use power of 2 scale parameter
 amax_epsilon: optional minimum value of abs max
+ random_hadamard_transform: whether to use random hadamard transform
+ stochastic_rounding: whether to use stochastic rounding
 """
 power_2_scale: bool = False
 amax_epsilon: float = 0.0
+ random_hadamard_transform: bool = False
+ stochastic_rounding: bool = False
+ fp4_2d_quantization: bool = False
+
+ def __repr__(self) -> str:
+ return (
+ f"Qparams(\npower_2_scale={self.power_2_scale},\n"
+ f"amax_epsilon={self.amax_epsilon},\n"
+ f"random_hadamard_transform={self.random_hadamard_transform},\n"
+ f"stochastic_rounding={self.stochastic_rounding},\n"
+ f"fp4_2d_quantization={self.fp4_2d_quantization}\n)"
+ )


class Recipe:
@@ -82,6 +110,10 @@
 Base recipe class.
 """

+ def nvfp4(self):
+ """Whether the given recipe is NVFP4 1D block scaling."""
+ return isinstance(self, NVFP4BlockScaling)
+
 def mxfp8(self):
 """Whether the given recipe is MXFP8 block scaling."""
 return isinstance(self, MXFP8BlockScaling)
@@ -102,6 +134,10 @@ def float8_block_scaling(self):
 """Whether the given recipe is float8 blockwise scaling."""
 return isinstance(self, Float8BlockScaling)

+ def custom(self):
+ """Whether the given recipe is custom."""
+ return isinstance(self, CustomRecipe)
+

@dataclass()
class DelayedScaling(Recipe):
@@ -147,7 +183,7 @@ def scaling_factor_compute(amax: Tensor,
 where `Tensor` is a framework tensor type.
 reduce_amax: bool, default = `True`
 By default, if `torch.distributed` is initialized, the `amax` value for FP8
- tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
+ tensors is reduced across the `amax_reduction_group` (specified in the `autocast`
 call). This keeps the amaxes and scaling factors synced across the given
 distributed group. If set to `False`, this reduction is skipped and every GPU
 maintains local amaxes and scaling factors. 
To ensure results are @@ -155,7 +191,7 @@ def scaling_factor_compute(amax: Tensor, ranks must checkpoint in order to store the local tensors. fp8_dpa: bool, default = `False` Whether to enable FP8 dot product attention (DPA). When the model is placed in an - `fp8_autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the + `autocast(enabled=True)` region and `fp8_dpa` is set to `True`, DPA casts the inputs from higher precision to FP8, performs attention in FP8, and casts tensors back to higher precision as outputs. FP8 DPA currently is only supported in the `FusedAttention` backend. @@ -200,6 +236,7 @@ def __repr__(self) -> str: f"margin={self.margin}, " f"format={str(self.fp8_format).split('.')[1]}, " f"amax_history_len={self.amax_history_len}, " + f"reduce_amax={self.reduce_amax}, " f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) @@ -217,10 +254,11 @@ class Float8CurrentScaling(Recipe): pass. """ + use_power_2_scales: bool = os.getenv("NVTE_FP8_CURRENT_SCALING_POWER_2_SCALES", "0") == "1" fp8_format: Format = Format.HYBRID - fp8_quant_fwd_inp = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_fwd_weight = QParams(power_2_scale=False, amax_epsilon=0.0) - fp8_quant_bwd_grad = QParams(power_2_scale=False, amax_epsilon=0.0) + fp8_quant_fwd_inp = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_fwd_weight = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) + fp8_quant_bwd_grad = QParams(power_2_scale=use_power_2_scales, amax_epsilon=0.0) fp8_gemm_fprop: MMParams = MMParams(use_split_accumulator=False) fp8_gemm_dgrad: MMParams = MMParams(use_split_accumulator=True) fp8_gemm_wgrad: MMParams = MMParams(use_split_accumulator=True) @@ -229,9 +267,6 @@ class Float8CurrentScaling(Recipe): def __post_init__(self) -> None: assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." 
- assert ( - not self.fp8_dpa and not self.fp8_mha - ), "FP8 attention is not supported for Float8CurrentScaling." def __repr__(self) -> str: return ( @@ -350,6 +385,7 @@ def __post_init__(self) -> None: assert ( not self.fp8_dpa and not self.fp8_mha ), "FP8 attention is not supported for Float8BlockScaling." + assert self.fp8_format != Format.E5M2, "Pure E5M2 training is not supported." def __repr__(self) -> str: return ( @@ -367,3 +403,134 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}" ) + + +@dataclass() +class NVFP4BlockScaling(Recipe): + """ + Use the NVFP4 scaling strategy. + + This is a 2-level block scaling strategy. In level 1, each group of + 16 consecutive values is scaled together using their own scaling + factor. The type of the scaling factor is E4M3 (4 bits of exponent, + 3 bits of mantissa). In level 2, a global per tensor FP32 scaling + factor is used to scale the entire tensor. + + Since the scaling happens in a particular direction (either rowwise + or columnwise), in this recipe the quantized tensor and its transpose + are not numerically equivalent. Due to this, when Transformer Engine + needs both the tensor and its transpose (e.g. to calculate both + forward and backward pass), during the quantization both versions are + computed from the high precision input to avoid double quantization + errors. + + The default NVFP4 training recipe implements 3 techniques for quantizing + to a narrow format (4-bit): + + - For weight tensors a variant of the NVFP4 quantization is used, + where a single scaling factor is shared by a 2D block of 16x16 elements. + - When quantizing gradients, stochastic rounding is applied to avoid the bias + introduced by quantization. With this, values are rounded probabilistically + to one of their two nearest representable numbers, with probabilities + inversely proportional to their distances. 
+ - When quantizing inputs and gradients, random Hadamard transforms are applied + (16x16 Hadamard matrix) to smooth outliers in the tensor distributions + and make them easier to represent accurately in NVFP4. + + These techniques are described more comprehensively in the NVFP4 paper titled + 'Pretraining Large Language Models with NVFP4' (https://arxiv.org/abs/2509.25149v1). + + Parameters + ---------- + fp4_format : {Format.E2M1}, default = Format.E2M1 + FP4 data type. + disable_rht : bool, default = `False` + If set to `True`, random Hadamard transforms are not applied to any tensor. + disable_stochastic_rounding : bool, default = `False` + If set to `True`, stochastic rounding is disabled during quantization for all tensors. + disable_2d_quantization : bool, default = `False` + If set to `True`, 1D block scaling with block size 16 is used for all tensors. + """ + + # Configuration envvars + disable_rht: bool = os.getenv("NVTE_NVFP4_DISABLE_RHT", "0") == "1" + disable_stochastic_rounding: bool = ( + os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" + ) + disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + + fp4_format: Format = Format.E2M1 + fp8_format: Format = Format.E4M3 + + # Not applying quantization to attention for now + fp8_dpa: bool = False + fp8_mha: bool = False + + def __post_init__(self) -> None: + assert self.fp4_format == Format.E2M1, "Only E2M1 is supported for NVFP4 scaling" + assert self.fp8_format == Format.E4M3, "Only E4M3 is supported for NVFP4 scaling" + + # Quantization params + # Note: RHT is currently only applied to column-wise usage so that + # it can be used for wgrad GEMM. 
+ self.fp4_quant_fwd_inp = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=False, + fp4_2d_quantization=False, + ) + self.fp4_quant_fwd_weight = QParams( + random_hadamard_transform=False, + stochastic_rounding=False, + fp4_2d_quantization=not self.disable_2d_quantization, + ) + self.fp4_quant_bwd_grad = QParams( + random_hadamard_transform=not self.disable_rht, + stochastic_rounding=not self.disable_stochastic_rounding, + fp4_2d_quantization=False, + ) + + def __repr__(self) -> str: + return ( + f"recipe_type={self.__class__.__name__}, " + f"fp4_format={str(self.fp4_format).split('.')[1]}, " + f"fp8_format={str(self.fp8_format).split('.')[1]}, " + f"fp8_dpa={self.fp8_dpa}, " + f"fp8_mha={self.fp8_mha}, " + f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " + f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " + f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " + ) + + +@dataclass() +class CustomRecipe(Recipe): + """ + Custom recipe that allows users to provide quantizer factories. + + .. warning:: + **EXPERIMENTAL**: Custom recipe is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. + + Parameters + ---------- + qfactory : Callable + Factory callable that returns a quantizer instance for a + given semantic tensor role. + The callable is typically invoked as: + qfactory( + role: str, + ) + + Where `role` is one of the following strings for e.g. 
te.Linear + (stable public contract): + - forward: "linear_input", "linear_weight", "linear_output" + - backward: "linear_grad_output", "linear_grad_input" + """ + + qfactory: Callable[..., Any] + + fp8_dpa: bool = False + fp8_mha: bool = False + + def __repr__(self) -> str: + return f"recipe_type={self.__class__.__name__}, qfactory={self.qfactory}" diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index f3c6b7952..55aa2907e 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -51,6 +51,13 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax, #endif +__launch_bounds__(1) __global__ void zero_amax_kernel(float *amax_ptr, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *amax_ptr = 0; +} + template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, @@ -118,7 +125,8 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, #endif const float *noop_ptr, cudaStream_t stream) { // Zero out amax so we can update with atomic max - NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); + zero_amax_kernel<<<1, 1, 0, stream>>>(amax, noop_ptr); + NVTE_CHECK_CUDA(cudaGetLastError()); // Return immediately if tensor is empty if (N == 0) { @@ -220,15 +228,17 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt // Check output tensor NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)"); auto &output = *convertNVTETensorCheck(output_); - NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, - "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, " + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING || + output.scaling_mode == NVTE_NVFP4_1D_SCALING, + "Output tensor for amax computation must be FP8 tensor with per-tensor scaling or " + "NVFP4 1D scaling, " "but got scaling_mode=", to_string(output.scaling_mode)); NVTE_CHECK(output.amax.numel() == 1, "Output tensor for amax computation has invalid amax tensor " "(expected 1 entry, got shape=", output.amax.shape, ")"); - NVTE_CHECK(output.amax.dptr != nullptr, + NVTE_CHECK(output.amax.dptr != nullptr || output.columnwise_amax.dptr != nullptr, "Output tensor for amax computation has amax tensor without data"); NVTE_CHECK(output.amax.dtype == DType::kFloat32, "Output tensor for amax computation has invalid amax tensor " @@ -264,14 +274,16 @@ void 
compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt } // Compute amax + float *amax_ptr = reinterpret_cast( + (output.amax.dptr != nullptr) ? output.amax.dptr : output.columnwise_amax.dptr); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); - launch_amax_kernel(reinterpret_cast(input.data.dptr), - reinterpret_cast(output.amax.dptr), input.data.numel(), + launch_amax_kernel( + reinterpret_cast(input.data.dptr), amax_ptr, input.data.numel(), #ifdef __HIP_PLATFORM_AMD__ - block_amax, block_capacity, + block_amax, block_capacity, #endif - noop_ptr, stream);); // NOLINT(*) + noop_ptr, stream);); // NOLINT(*) } } // anonymous namespace diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu new file mode 100644 index 000000000..5ebc7ba4f --- /dev/null +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -0,0 +1,54 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include + +#include + +#include "../common.h" +#include "../utils.cuh" + +namespace transformer_engine { +namespace nvfp4_recipe { + +// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; +constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); + +// Kernel to compute alpha *= amax_A * amax_B / factor +__global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, + const float *amax_B, float *alpha_out) { + // factor is defined in the enclosing namespace + *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; +} + +} // namespace nvfp4_recipe +} // namespace transformer_engine + +void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_rowwise_amax_A, + const NVTETensor inpB, const bool use_rowwise_amax_B, + float alpha_in, NVTETensor alpha_out, + cudaStream_t stream) { + NVTE_API_CALL(nvte_nvfp4_compute_per_tensor_scale); + using namespace transformer_engine; + + auto *tA = convertNVTETensor(inpA); + auto *tB = convertNVTETensor(inpB); + auto *tOut = convertNVTETensor(alpha_out); + + void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; + void *amax_B_ptr = use_rowwise_amax_B ? 
tB->amax.dptr : tB->columnwise_amax.dptr; + void *alpha_ptr = tOut->data.dptr; + + // check for not null pointers + NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); + NVTE_CHECK(amax_B_ptr != nullptr, "amax_B_ptr is null"); + NVTE_CHECK(alpha_ptr != nullptr, "alpha_ptr is null"); + + nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( + alpha_in, reinterpret_cast(amax_A_ptr), + reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 499f7bcff..a00c30a9c 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -24,8 +24,10 @@ namespace { #define __ldg(x) (*(x)) #endif +constexpr int MXFP8_BLOCK_SIZE = 32; #ifndef __HIP_PLATFORM_AMD__ -constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr int NVFP4_BLOCK_SIZE = 16; + constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; @@ -36,7 +38,6 @@ constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; #else // HIPCC does not support __host__ qualifier for variables // and constexpr values do not need __device__ qualifier because they are compile-time constants -constexpr int MXFP8_BLOCK_SIZE = 32; constexpr int TB_DIM = 32; constexpr int NEW_SF_TILE_DIM_K = 16; constexpr int N_SF_PER_TD_PER_TILE = 4; @@ -333,8 +334,6 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ const int original_K = kernel_args.original_k_list[tensor_id]; constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); - constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; - constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; // Get block index in grid. Emulate 2D grid. 
const int num_tiles_k = K / SF_TILE_DIM_K; @@ -351,9 +350,11 @@ __global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_ } // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { - if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { - NVTE_ERROR("Not implemented caling mode " + to_string(input->scaling_mode) + "."); - } + NVTE_CHECK( + input->scaling_mode == NVTE_MXFP8_1D_SCALING || input->scaling_mode == NVTE_NVFP4_1D_SCALING, + "Input tensor has invalid scaling mode (", to_string(input->scaling_mode), ")."); + NVTE_CHECK(is_fp8_dtype(input->dtype()) || is_fp4_dtype(input->dtype()), + "Input tensor has invalid dtype (", to_string(input->dtype()), ")."); // Do nothing if tensor is empty if (input->data.numel() == 0) { @@ -364,135 +365,151 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s CheckInputTensor(*output, "scaling_factor_output"); auto& scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_MXFP8_1D_SCALING || scaling_mode == NVTE_NVFP4_1D_SCALING, + "Unsupported scaling mode for swizzling."); + + bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; // 1D block scaling, row-wise or colum-wise - if (scaling_mode == NVTE_MXFP8_1D_SCALING) { - const int m = - input->has_data() ? input->scale_inv.shape[0] : input->columnwise_scale_inv.shape[1]; - const int k = - input->has_data() ? 
input->scale_inv.shape[1] : input->columnwise_scale_inv.shape[0]; - - constexpr int SF_TILE_DIM_M = 128; - constexpr int SF_TILE_DIM_K = 4; - - NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); - NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); - NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - if (output->has_data()) { - NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), - output->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); - } - if (output->has_columnwise_data()) { - NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), - output->columnwise_scale_inv.shape.end(), 1, - std::multiplies()), - "Input.columnwise_scale_inv size is not equal to " - "Output.columnwise_scale_inv size!"); + + int m, k; + if (input->has_data()) { + m = input->scale_inv.shape[0]; + k = input->scale_inv.shape[1]; + } else { + if (nvfp4) { + m = input->columnwise_scale_inv.shape[0]; + k = input->columnwise_scale_inv.shape[1]; + } else { + m = input->columnwise_scale_inv.shape[1]; + k = input->columnwise_scale_inv.shape[0]; } + } - int num_tiles_m = m / SF_TILE_DIM_M; - int num_tiles_k = k / SF_TILE_DIM_K; + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; - dim3 block_size(TB_DIM, TB_DIM); - if (input->has_data()) { - int vec_load_size = (num_tiles_k - 1) % 4 + 1; - /* there is no int3 and misaligned if using int4/int2 */ - if (vec_load_size == 3) vec_load_size = 1; - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_first_dim(); - const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - 
cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 2: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - case 1: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_row_scaling_kernel - <<>>( - input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + if (output->has_data()) { + NVTE_CHECK(m * k == std::accumulate(output->scale_inv.shape.begin(), + output->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output->columnwise_scale_inv.shape.begin(), + output->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } + + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + + // For NVFP4, the scale inverse for tranposed data needs rowwise swizzle. 
+ const bool rowwise_swizzle = input->has_data() || nvfp4; + const bool columnwise_swizzle = input->has_columnwise_data() && !nvfp4; + + dim3 block_size(TB_DIM, TB_DIM); + if (rowwise_swizzle) { + int vec_load_size = (num_tiles_k - 1) % 4 + 1; + /* there is no int3 and misaligned if using int4/int2 */ + if (vec_load_size == 3) vec_load_size = 1; + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + + int original_M, original_K; + void *input_scale_inv_ptr, *output_scale_inv_ptr; + + if (!nvfp4 || input->has_data()) { + int block_scale_size = nvfp4 ? NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + original_M = input->flat_first_dim(); + original_K = input->flat_last_dim() / block_scale_size; + input_scale_inv_ptr = input->scale_inv.dptr; + output_scale_inv_ptr = output->scale_inv.dptr; + } else { + original_M = input->flat_last_dim(); + original_K = input->flat_first_dim() / NVFP4_BLOCK_SIZE; + input_scale_inv_ptr = input->columnwise_scale_inv.dptr; + output_scale_inv_ptr = output->columnwise_scale_inv.dptr; } - if (input->has_columnwise_data()) { - int vec_load_size = (num_tiles_m - 1) % 4 + 1; - if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ - int n_tiles_in_tb = TB_DIM * vec_load_size; - dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); - int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); - const int original_M = input->flat_last_dim(); - const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; - switch (vec_load_size) { - case 4: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, 
original_K); - break; - case 2: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - case 1: -#ifndef __HIP_PLATFORM_AMD__ - NVTE_CHECK_CUDA( - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); -#endif - swizzle_col_scaling_kernel - <<>>(input->columnwise_scale_inv.dptr, - output->columnwise_scale_inv.dptr, m, - k, original_M, original_K); - break; - default: - NVTE_ERROR("Not valid vec_load_size."); - break; - } - NVTE_CHECK_CUDA(cudaGetLastError()); + + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_row_scaling_kernel + <<>>( + input_scale_inv_ptr, output_scale_inv_ptr, m, k, original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; } + } + if (columnwise_swizzle) { + int vec_load_size = (num_tiles_m - 1) % 4 + 1; + if (vec_load_size == 3) vec_load_size = 1; /* no int3 and misaligned if using int4/int2 */ + int n_tiles_in_tb = TB_DIM * vec_load_size; + dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); + int slm_size = n_tiles_in_tb * 
SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); - // 2D block scaling - } else { - NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); + switch (vec_load_size) { + case 4: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 2: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + case 1: + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); + swizzle_col_scaling_kernel + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, k, + original_M, original_K); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -582,15 +599,20 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, } NVTE_CHECK_CUDA(cudaGetLastError()); } + void multi_tensor_swizzle_scaling_factors(const std::vector& input, std::vector& output, cudaStream_t stream) { auto num_tensors = input.size(); bool all_has_data = true; bool all_has_columnwise_data = true; + bool all_nvfp4 = true; for (size_t i = 0; i < num_tensors; i++) { - if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { - NVTE_ERROR("Not implemented caling 
mode " + to_string(input[i]->scaling_mode) + "."); - } + auto scaling_mode = input[i]->scaling_mode; + auto is_fp8 = is_fp8_dtype(input[i]->dtype()); + auto is_fp4 = is_fp4_dtype(input[i]->dtype()); + NVTE_CHECK( + (is_fp8 && is_mxfp8_scaling(scaling_mode)) || (is_fp4 && is_nvfp4_scaling(scaling_mode)), + "Not implemented scaling mode " + to_string(scaling_mode) + "."); // We don't allow empty tensors. They should be filtered out before calling this function. if (input[i]->data.numel() == 0) { NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); @@ -599,13 +621,17 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); all_has_data &= input[i]->has_data(); all_has_columnwise_data &= input[i]->has_columnwise_data(); + all_nvfp4 &= is_nvfp4_scaling(scaling_mode); } NVTE_CHECK(all_has_data || all_has_columnwise_data, "All tensors should have data or columnwise data."); + const bool rowwise_swizzle = all_has_data || all_nvfp4; + const bool columnwise_swizzle = all_has_columnwise_data && !all_nvfp4; + constexpr int SF_TILE_DIM_M = 128; constexpr int SF_TILE_DIM_K = 4; - if (all_has_data) { + if (rowwise_swizzle) { MultiSwizzleArgs kernel_args; kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; @@ -621,29 +647,60 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args.num_tensors = 0; vec_load_size = 4; } - const int m = input[i]->scale_inv.shape[0]; - const int k = input[i]->scale_inv.shape[1]; + + int m, k; + + if (all_has_data) { + m = input[i]->scale_inv.shape[0]; + k = input[i]->scale_inv.shape[1]; + } else { + NVTE_CHECK(all_nvfp4, "When doing rowwise swizzle with rowwise data, it has to be NVFP4"); + m = input[i]->columnwise_scale_inv.shape[0]; + k = input[i]->columnwise_scale_inv.shape[1]; + } NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should 
be padded in K dimension!"); NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); - NVTE_CHECK( - m * k == std::accumulate(output[i]->scale_inv.shape.begin(), - output[i]->scale_inv.shape.end(), 1, std::multiplies()), - "Input.scale_inv size is not equal to Output.scale_inv size!"); + + if (output[i]->has_data()) { + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + } + if (output[i]->has_columnwise_data()) { + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + } int num_tiles_k = k / SF_TILE_DIM_K; int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; // We use the minimum vec_load_size across all tensors. - vec_load_size = std::min(vec_load_size, vec_load_size_i); + // TODO(zhongbo): fix vec_load_size for NVFP4 + // Current unit test won't capture this issue, but in E2E + // using vec_load_size = 1 other than 1 will lead to mis-aligned + // address error in MOE training + vec_load_size = all_nvfp4 ? 1 : std::min(vec_load_size, vec_load_size_i); const int pos = kernel_args.num_tensors; - kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); - kernel_args.output_list[pos] = output[i]->scale_inv.dptr; kernel_args.m_list[pos] = m; kernel_args.k_list[pos] = k; - kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); - kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + if (!all_nvfp4 || all_has_data) { + int block_scale_size = all_nvfp4 ? 
NVFP4_BLOCK_SIZE : MXFP8_BLOCK_SIZE; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / block_scale_size; + } else { + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / NVFP4_BLOCK_SIZE; + } kernel_args.num_tensors++; } // Launch the remaining tensors @@ -653,7 +710,10 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, kernel_args, vec_load_size, true, stream); } - if (all_has_columnwise_data) { + if (columnwise_swizzle) { + // NVFP4 shouldn't end up here because it only needs rowwise swizzle + NVTE_CHECK(!all_nvfp4, "NVFP4 shouldn't end up here because it only needs rowwise swizzle"); + MultiSwizzleArgs kernel_args; kernel_args.num_tensors = 0; kernel_args.block_range[0] = 0; @@ -708,7 +768,7 @@ void multi_tensor_swizzle_scaling_factors(const std::vector& input, * WIP (Phuong): * - Opt for bank conflicts * - Adding swizzle for 2d-block scaling. -*/ + */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_swizzle_scaling_factors); using namespace transformer_engine; diff --git a/transformer_engine/common/swizzle/swizzle_block_scaling.cu b/transformer_engine/common/swizzle/swizzle_block_scaling.cu new file mode 100644 index 000000000..4be85474a --- /dev/null +++ b/transformer_engine/common/swizzle/swizzle_block_scaling.cu @@ -0,0 +1,321 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace { +constexpr uint32_t WARP_SIZE = 32; +} // namespace +namespace swizzle_kernel_1d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +// Transposes a 4x4 matrix of bytes stored across four threads with consecutive thread ids where +// each thread stores a single row (of four bytes). +// Example: +// lane0.row = 0x00010203 +// lane1.row = 0x04050607 +// lane2.row = 0x08090a0b +// lane3.row = 0x0c0d0e0f +// Becomes: +// lane0.row = 0x0004080c +// lane1.row = 0x0105090d +// lane2.row = 0x02060a0e +// lane3.row = 0x03070b0f +uint32_t __device__ __forceinline__ transpose_4x4_byte_matrix(const uint32_t row, + const uint32_t lane, + const uint32_t active_mask) { + using cu = const uint32_t; + + // Threads operate in groups of 4, and each thread stores 4 bytes at a time. + // The bytes in this 4x4 matrix are labeled in hex. We shuffle around bytes + // until we have transposed the 4x4 matrix. + cu m_0123_4567_89ab_cdef = row; + cu m_4567_0123_cdef_89ab = __shfl_xor_sync(active_mask, m_0123_4567_89ab_cdef, 1, 4); + cu m_0426_4062_8cae_c8ea = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x6240); + cu m_5173_1537_d9fb_9dbf = __byte_perm(m_0123_4567_89ab_cdef, m_4567_0123_cdef_89ab, 0x3715); + cu m_0426_1537_8cae_9dbf = (lane & 1) ? m_5173_1537_d9fb_9dbf : m_0426_4062_8cae_c8ea; + cu m_8cae_9dbf_0426_1537 = __shfl_xor_sync(active_mask, m_0426_1537_8cae_9dbf, 2, 4); + cu m_048c_159d_8c04_9d15 = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x5410); + cu m_ae26_bf37_26ae_37bf = __byte_perm(m_0426_1537_8cae_9dbf, m_8cae_9dbf_0426_1537, 0x3276); + cu m_048c_159d_26ae_37bf = (lane & 2) ? 
m_ae26_bf37_26ae_37bf : m_048c_159d_8c04_9d15; + + return m_048c_159d_26ae_37bf; +} + +// Expands a uint32_t to a uint4 by duplicating each byte four times. +// Example: 0x01020304u becomes uint4{0x01010101, 0x02020202, 0x03030303, 0x04040404} +uint4 __device__ __forceinline__ broadcast_uint32_t_to_uint4(uint32_t x) { + return {__byte_perm(x, 0, 0x0000), __byte_perm(x, 0, 0x1111), __byte_perm(x, 0, 0x2222), + __byte_perm(x, 0, 0x3333)}; +} + +// Tag struct denoting whether the number of rows of the input fp8 block scaling tensor's data +// matrix is divisible by 128. If it is not, some threads could read out of bounds scaling factors. +struct no_oob_tag_t {}; +constexpr no_oob_tag_t NO_OOB_TAG; + +template +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride, + OOBT first_oob) { + // resolve kernel variant + constexpr bool no_oob = std::is_same_v; + static_assert(no_oob || std::is_same_v); + + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_x; + const uint32_t in_tile_x = out_tile_y; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = WARP_SIZE * sizeof(uint4); + const void* const warp_src = in + in_tile_y * in_y_stride + in_tile_x * in_x_stride; + + // load scaling 
factors for this lane's initial four 1x128 tiles + uint4 sf; + if constexpr (no_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + if ((out_tile_y < tiles_y - 1) || lane < first_oob) { + sf = reinterpret_cast(warp_src)[lane]; + } else { + sf = uint4{0, 0, 0, 0}; + } + } + + // pack the exponent bits of the scaling factors + uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1); + + // partially swizzle the scaling factors + constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches + const uint32_t lane_load_idx = (lane % 4) * 8 + (lane / 4); + packed_exponents = __shfl_sync(ACTIVE_MASK, packed_exponents, lane_load_idx); + + // transpose 4x4 matrices of scaling factors + packed_exponents = transpose_4x4_byte_matrix(packed_exponents, lane % 4, ACTIVE_MASK); + + // broadcast the scaling factors for sixteen 1x32 tiles + sf = broadcast_uint32_t_to_uint4(packed_exponents); + + // store them cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(uint4)), "Input scaling factor pointer must be aligned to ", + alignof(uint4), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + NVTE_CHECK(data_rows % 4 == 0, "Input tensor must not have any padding scaling factors"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 128x1 tile in the input scales + 
// and a 128x4 tile in the output scales. The input scales are in transposed order. + const uint32_t input_scale_inv_cols = DIVUP(data_rows, 4u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + const uint32_t first_oob = (input_scale_inv_cols % 128) / 4; + + if (first_oob == 0) { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, NO_OOB_TAG); + } else { + swizzle_block_scaling_1d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride, first_oob); + } +} +} // namespace swizzle_kernel_1d +namespace swizzle_kernel_2d { +constexpr uint32_t WARPS_X_PER_TB = 2; // configurable +constexpr uint32_t WARPS_Y_PER_TB = 2; // configurable + +void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE) + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel( + const void* __restrict__ const in, void* __restrict__ const out, const uint32_t tiles_x, + const uint32_t tiles_y, const uint32_t in_y_stride, const uint32_t out_y_stride) { + // load thread indices + const uint32_t lane = threadIdx.x; + __builtin_assume(lane < WARP_SIZE); + const uint32_t warp_x = threadIdx.z; + __builtin_assume(warp_x < WARPS_X_PER_TB); + const uint32_t warp_y = threadIdx.y; + __builtin_assume(warp_y < WARPS_Y_PER_TB); + + // compute tile indices + const uint32_t out_tile_y = blockIdx.y * WARPS_Y_PER_TB + warp_y; + const uint32_t out_tile_x = blockIdx.x * WARPS_X_PER_TB + warp_x; + const uint32_t in_tile_y = out_tile_y; + const uint32_t in_tile_x = out_tile_x; + + // bounds check; uniform branch + if (out_tile_y >= tiles_y || out_tile_x >= tiles_x) { + return; + } + + // calculate this warp's input base pointer + constexpr uint32_t in_x_stride = sizeof(float); + const void* const warp_src = in + in_tile_y * in_y_stride + 
in_tile_x * in_x_stride; + + // load scaling factor for this warp's 128x128 tile + uint32_t sf = *reinterpret_cast(warp_src); + + // broadcast it to four scaling factors for 1x32 tiles + sf = (sf << 1) | (sf >> 7); + sf = sf | (sf >> 16); + + // broadcast it to sixteen scaling factors for 1x32 tiles + const uint4 sf4{sf, sf, sf, sf}; + + // store it cooperatively for 512 1x32 tiles in a 128x128 tile + constexpr uint32_t out_x_stride = 512; + void* const warp_dst = out + out_tile_y * out_y_stride + out_tile_x * out_x_stride; + reinterpret_cast(warp_dst)[lane] = sf4; +} + +void launch_kernel(const void* const in, void* const out, uint32_t data_rows, uint32_t data_cols, + cudaStream_t stream) { + NVTE_CHECK(is_aligned_ptr(in, alignof(float)), "Input scaling factor pointer must be aligned to ", + alignof(float), " bytes"); + NVTE_CHECK(is_aligned_ptr(out, alignof(uint4)), + "Output scaling factor pointer must be aligned to ", alignof(uint4), " bytes"); + + const uint32_t tiles_x = DIVUP(data_cols, 128u); + const uint32_t tiles_y = DIVUP(data_rows, 128u); + const dim3 grid_dim{DIVUP(tiles_x, WARPS_X_PER_TB), DIVUP(tiles_y, WARPS_Y_PER_TB), 1}; + const dim3 block_dim{WARP_SIZE, WARPS_Y_PER_TB, WARPS_X_PER_TB}; + + // Each 128x128 tile in the data corresponds to a 1x1 tile in the input scales + // and a 128x4 tile in the output scales. 
+ const uint32_t input_scale_inv_cols = DIVUP(data_cols, 512u) * 4; + const uint32_t output_scale_inv_cols = tiles_x * 128 * 4; + const uint32_t in_y_stride = input_scale_inv_cols * sizeof(float); + const uint32_t out_y_stride = output_scale_inv_cols * sizeof(uint8_t); + + swizzle_block_scaling_2d_to_mxfp8_scaling_factors_kernel<<>>( + in, out, tiles_x, tiles_y, in_y_stride, out_y_stride); +} +} // namespace swizzle_kernel_2d + +void swizzle_block_scaling_to_mxfp8_scaling_factors(const Tensor* input, Tensor* output, + cudaStream_t stream) { + // Do nothing if tensor is empty + if (input->data.numel() == 0) { + return; + } + + CheckInputTensor(*input, "block_scaling_scaling_factor_input"); + CheckInputTensor(*output, "mxfp8_scaling_factor_output"); + + const NVTEScalingMode scaling_mode = input->scaling_mode; + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + NVTE_CHECK(output->scaling_mode == NVTE_MXFP8_1D_SCALING, + "Output tensor must be an mxfp8 tensor"); + + NVTE_CHECK(input->data.dtype == transformer_engine::DType::kFloat8E4M3 || + input->data.dtype == transformer_engine::DType::kFloat8E5M2, + "Input data must have FP8E4M3 or FP8E5M2 dtype to be compatible with MXFP8"); + NVTE_CHECK(output->data.dtype == input->data.dtype, + "Output data must have the same dtype as input data"); + NVTE_CHECK(input->scale_inv.dtype == DType::kFloat32, "Input must have FP32 scaling factors"); + NVTE_CHECK(output->scale_inv.dtype == DType::kFloat8E8M0, + "Output must have E8M0 scaling factors"); + + NVTE_CHECK(input->data.dptr != nullptr, "Input must have rowwise data"); + NVTE_CHECK(output->data.dptr == input->data.dptr, "Output must share data with input"); + NVTE_CHECK(input->scale_inv.dptr != nullptr, "Input must have rowwise scaling factors"); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Output must have rowwise scaling factors"); + + NVTE_CHECK(input->data.shape.size() == 2, 
"Input data must be a matrix"); + NVTE_CHECK(output->data.shape == input->data.shape, + "Output data must have the same shape as input data"); + NVTE_CHECK(input->scale_inv.shape.size() == 2, "Input scaling factors must be a matrix"); + NVTE_CHECK(output->scale_inv.shape.size() == 2, "Output scaling factors must be a matrix"); + + const size_t data_rows = input->data.shape[0]; + const size_t data_cols = input->data.shape[1]; + const size_t input_scale_inv_rows = input->scale_inv.shape[0]; + const size_t input_scale_inv_cols = input->scale_inv.shape[1]; + const size_t output_scale_inv_rows = output->scale_inv.shape[0]; + const size_t output_scale_inv_cols = output->scale_inv.shape[1]; + + NVTE_CHECK(output_scale_inv_rows == DIVUP(data_rows, 128) * 128, + "Expected the output scaling factor matrix to have ", + DIVUP(data_rows, 128) * 128, " rows, but it has ", output_scale_inv_rows, + " rows instead."); + NVTE_CHECK(output_scale_inv_cols == DIVUP(data_cols, 128) * 4, + "Expected the output scaling factor matrix to have ", + DIVUP(data_cols, 128) * 4, " columns, but it has ", output_scale_inv_cols, + " columns instead."); + + if (scaling_mode == NVTE_BLOCK_SCALING_1D) { + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_cols, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_cols, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + NVTE_CHECK(input_scale_inv_cols == DIVUP(data_rows, 4) * 4, + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 4) * 4, + " columns, but it has ", input_scale_inv_cols, " columns instead."); + + swizzle_kernel_1d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } else { // scaling_mode == NVTE_BLOCK_SCALING_2D + NVTE_CHECK(input_scale_inv_rows == DIVUP(data_rows, 128), + "Expected the input scaling factor matrix to have ", DIVUP(data_rows, 128), + " rows, but it has ", input_scale_inv_rows, " rows instead."); + 
NVTE_CHECK(input_scale_inv_cols == DIVUP(data_cols, 512) * 4, + "Expected the input scaling factor matrix to have ", + DIVUP(data_cols, 512) * 4, " columns, but it has ", input_scale_inv_cols, + " columns instead."); + + swizzle_kernel_2d::launch_kernel(input->scale_inv.dptr, output->scale_inv.dptr, data_rows, + data_cols, stream); + } +} + +} // namespace transformer_engine + +void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_swizzle_block_scaling_to_mxfp8_scaling_factors); + using namespace transformer_engine; + swizzle_block_scaling_to_mxfp8_scaling_factors(convertNVTETensorCheck(input), + convertNVTETensorCheck(output), stream); +} diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 68d1f0ec5..77b922b0a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include "common.h" #include "common/util/cuda_runtime.h" @@ -65,8 +66,12 @@ std::string to_string(const NVTEScalingMode &mode) { return "NVTE_DELAYED_TENSOR_SCALING"; case NVTE_MXFP8_1D_SCALING: return "NVTE_MXFP8_1D_SCALING"; - case NVTE_FWD_NVFP4_BWD_MXFP8_SCALING: - return "NVTE_FWD_NVFP4_BWD_MXFP8_SCALING"; + case NVTE_BLOCK_SCALING_1D: + return "NVTE_BLOCK_SCALING_1D"; + case NVTE_BLOCK_SCALING_2D: + return "NVTE_BLOCK_SCALING_2D"; + case NVTE_NVFP4_1D_SCALING: + return "NVTE_NVFP4_1D_SCALING"; case NVTE_INVALID_SCALING: return "NVTE_INVALID_SCALING"; } @@ -96,12 +101,11 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { t.columnwise_scale_inv.shape, ")"); } } else { - if (t.scaling_mode == NVTE_MXFP8_1D_SCALING || - t.scaling_mode == NVTE_FWD_NVFP4_BWD_MXFP8_SCALING) { + if (t.scaling_mode == NVTE_MXFP8_1D_SCALING) { // Need (4, 128) alignment even for e8 scaling factor auto block_alignment = 
std::vector{128ul, 4ul}; size_t expected_x, expected_y, alignment; - const size_t block_size_rowwise = (t.scaling_mode == NVTE_MXFP8_1D_SCALING) ? 32 : 16; + const size_t block_size_rowwise = 32; const size_t block_size_colwise = 32; if (t.has_data()) { @@ -112,6 +116,7 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(block_size_rowwise)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, "\" has invalid scale_inv shape (expected ", expected, ", got ", @@ -124,11 +129,29 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name) { alignment; alignment = block_alignment[0]; expected_y = DIVUP(DIVUP(t.flat_last_dim(), static_cast(1)), alignment) * alignment; + const auto &expected = std::vector{expected_x, expected_y}; NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", t.columnwise_scale_inv.shape, ")"); } + } else if (t.scaling_mode == NVTE_NVFP4_1D_SCALING) { + if (t.has_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_first_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_last_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid scale_inv shape (expected ", expected, ", got ", + t.scale_inv.shape, ")"); + } + if (t.has_columnwise_data()) { + const size_t expected_y = DIVUP_TO_MULTIPLE(t.flat_last_dim(), 128); + const size_t expected_x = DIVUP_TO_MULTIPLE(DIVUP(t.flat_first_dim(), 16lu), 4); + const auto &expected = std::vector{expected_y, expected_x}; + NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name, + "\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ", + t.columnwise_scale_inv.shape, ")"); + 
} } } } @@ -156,6 +179,26 @@ void CheckInputTensor(const Tensor &t, const std::string &name) { "(expected Float32 or Byte, got ", to_string(t.columnwise_scale_inv.dtype), ")"); } + } else if (is_fp4_dtype(type)) { + // TODO(ksivaman): Fix this to check for amaxes and other details. + // For now only needed for swizzle. + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor input ", name, + "_scale_inverse has invalid dtype " + "(expected DType::kFloat8E4M3, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor input ", name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP8 scaling factor input ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected DType::kFloat8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input ", name); NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input ", name); @@ -197,10 +240,29 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt "(expected Float32 or Float8E8M0, got ", to_string(t.columnwise_scale_inv.dtype), ")"); } + } else if (is_fp4_dtype(type)) { + // FP4 output needs to have the scale_inv + if (t.has_data()) { + NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP4 scaling factor output ", name, + "_scale_inverse must be allocated"); + NVTE_CHECK(t.scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", name, + "_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.scale_inv.dtype), ")"); + } + if (t.has_columnwise_data()) { + NVTE_CHECK(t.columnwise_scale_inv.dptr != nullptr, "FP4 scaling factor output ", 
name, + "_columnwise_scale_inverse must be allocated"); + NVTE_CHECK(t.columnwise_scale_inv.dtype == DType::kFloat8E4M3, "FP4 scaling factor output ", + name, + "_columnwise_scale_inverse has invalid dtype " + "(expected Float8E4M3, got ", + to_string(t.columnwise_scale_inv.dtype), ")"); + } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); - // Note: amax is supported for non-FP8 output as it can be fused into the computation - // and later used for quantization with no need to compute it separately + // Unfused quant with level 2 nvfp4 scaling will produce high precision tensors with amax. + // NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); @@ -493,6 +555,9 @@ void nvte_set_tensor_param(NVTETensor *tensor, NVTETensorParam param_name, case kNVTEColumnwiseScaleInv: t->columnwise_scale_inv = *param; break; + case kNVTEColumnwiseAmax: + t->columnwise_amax = *param; + break; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -516,6 +581,8 @@ NVTEBasicTensor nvte_get_tensor_param(const NVTETensor tensor, NVTETensorParam p return t.scale_inv; case kNVTEColumnwiseScaleInv: return t.columnwise_scale_inv; + case kNVTEColumnwiseAmax: + return t.columnwise_amax; default: NVTE_ERROR("Unknown tensor parameter!"); } @@ -631,6 +698,15 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigFloat8BlockScaleTensorFormat: std::memcpy(&config_.float8_block_scale_tensor_format, buf, attr_size); break; + case kNVTEQuantizationConfigRNGState: + std::memcpy(&config_.rng_state, buf, attr_size); + break; + case kNVTEQuantizationConfigNVFP42DQuantization: + std::memcpy(&config_.nvfp4_2d_quantization, buf, attr_size); + break; + case 
kNVTEQuantizationConfigStochasticRounding: + std::memcpy(&config_.stochastic_rounding, buf, attr_size); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index abfa226e8..89266f4bb 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -8,6 +8,7 @@ #define TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ #include "../common.h" +#include "transformer_engine/transformer_engine.h" namespace transformer_engine::detail { @@ -62,6 +63,14 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor const bool pow_2_scale, const SimpleTensor &noop_tensor, cudaStream_t stream); +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor &input, const SimpleTensor &global_amax, SimpleTensor &scale_inv, + SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor &noop_tensor, cudaStream_t stream); + } // namespace transformer_engine::detail #endif // TRANSFORMER_ENGINE_COMMON_TRANSPOSE_CAST_TRANSPOSE_H_ diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index c3f085b87..661cf339a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,6 +14,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include 
"common/utils.cuh" @@ -485,6 +486,11 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor NVTE_API_CALL(quantize_transpose_square_blockwise); checkCuDriverContext(stream); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow_2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index d38bf7996..fcf7a151c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,6 +17,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -529,6 +530,11 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); + if (transformer_engine::cuda::sm_arch() >= 100) { + NVTE_CHECK(pow2_scale, "On Blackwell and newer, the FP8 block scaling recipe is emulated ", + "with MXFP8, which requires using power of two scaling factors."); + } + const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; size_t num_rows = 1; diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu new file mode 100644 index 000000000..b49a54fbd --- /dev/null +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -0,0 +1,839 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "common/common.h" +#include "common/recipe/recipe_common.cuh" +#include "common/transpose/cast_transpose.h" +#include "common/util/curanddx.hpp" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { + +#if CUDA_VERSION >= 12080 +namespace quantize_transpose_nvfp4 { +namespace { + +using std::int32_t; +using std::uint32_t; +using std::uint8_t; + +using transformer_engine::detail::TypeExtrema; + +// clang-format off +/* + +Step 1: Load input to shared memory +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 8 times +* What each thread does in each loop: + * 8 elements are read from the input at a time + * 2 elements are written to the shared memory at a time, for a total of 4 times ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | 
++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 1 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| Warp 7 | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 8 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 2: Cast and store to output_c +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 4 times +* What each thread does in each loop: + * 2 elements are read from the shared memory at a time, for a total of 8 times + * Every 8 consecutive threads do reduction and calculate the amax of each row + * 16 elements are quantized and write to output_c at a time ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| T0 | T1 | T2 | T3 | T4 | T5 | T6 | T7 | +| T8 | T9 | T10 | T11 | T12 | T13 | T14 | T15 | +| T16 | T17 | T18 | T19 | T20 | T21 | T22 | T23 | +| T24 | T25 | T26 | T27 | T28 | T29 | T30 | T31 | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 1 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... 
| ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| | +| Warp 7 | +| | +| | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ +| ... | +| ... | +| ... | +| ... | +| Loop 4 times | +| ... | +| ... | +| ... | +| ... | ++-------------------------------+-------------------------------+-------------------------------+-------------------------------+ + +Step 3: Transpose, cast and store to output_t +* shard memory: 128x128 elements with type=InputType (below graph doesn't consider padding) +* 8 warps +* Loop 2 times +* What each thread does in each loop: + * 2 elements (in a row) are read from the shared memory at a time, for a total of 16 times + * Every 8 consecutive threads do reduction and calculate the amax of each column + * 16 elements are quantized and write to output_c at a time, for a total of 2 times ++------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+------8 elements-------+------8 elements-------+-----40 elements-------+------8 elements-------+ +| T0 | T8 | T16 | T24 | | | | T0 | T8 | T16 | T24 | | | | +| T1 | T9 | T17 | T25 | | | | T1 | T9 | T17 | T25 | | | | +| T2 | T10 | T18 | T26 | | | | T2 | T10 | T18 | T26 | | | | +| T3 | T11 | T19 | T27 | Warp 1 | ... | Warp 7 | T3 | T11 | T19 | T27 | Warp 1 | ... 
| Warp 7 | +| T4 | T12 | T20 | T28 | | | | T4 | T12 | T20 | T28 | | | | +| T5 | T13 | T21 | T29 | | | | T5 | T13 | T21 | T29 | | | | +| T6 | T14 | T22 | T30 | | | | T6 | T14 | T22 | T30 | | | | +| T7 | T15 | T23 | T31 | | | | T7 | T15 | T23 | T31 | | | | ++-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+-----------------------+ + +*/ +// clang-format on + +constexpr int kThreadsPerWarp = 32; + +// for fp4, we use uint8_t to store 2 fp4 numbers +constexpr int kNFP4PerContainer = 2; + +// Hyperparameters for performance tuning +constexpr int kTileDim = 128; +// constexpr int kScaleDim = 32; +constexpr int kNVecIn = 8; // The number of elements each LDG touches +constexpr int kNVecOut = 16; // The number of elements each STG touches +constexpr int kNVecSMem = 2; // The number of elements each LDS/STS touches +constexpr int kThreadsPerBlock = 256; // Thread block size, 8 warps in total + +// Auto-calculated constants, do not modify directly) +static_assert(kNVecIn % kNVecSMem == 0, "kNVecIn must be divisible by kNVecSMem"); +static_assert(kNVecOut % kNVecSMem == 0, "kNVecOut must be divisible by kNVecSMem"); +constexpr int kSMemRow = kTileDim; +constexpr int kSMemCol = (kTileDim / kNVecSMem) + 1; +constexpr int kSMemSize = kSMemRow * kSMemCol * kNVecSMem; +constexpr int kNumThreadsLoad = kTileDim / kNVecIn; // 16 +constexpr int kNumThreadsStore = kTileDim / kNVecOut; // 8 +// constexpr int kNumThreadsReduce = kScaleDim / kNVecOut; +static_assert(kNumThreadsLoad <= kThreadsPerWarp, "kNumThreadsLoad must be <= kThreadsPerWarp"); +static_assert(kNumThreadsStore <= kThreadsPerWarp, "kNumThreadsStore must be <= kThreadsPerWarp"); + +// for 2D block scaling, we need to reduce amax in warp +static __device__ constexpr unsigned int WARP_REDUCE_AMAX_GROUP_MASKS[8] = { + 0x01010101, 0x02020202, 0x04040404, 0x08080808, 0x10101010, 0x20202020, 0x40404040, 
0x80808080}; + +// max for every group_size elements in warp +template +__device__ __forceinline__ float groupMax(float val, unsigned int groupMask) { + for (int offset = group_size / 2; offset > 0; offset /= 2) { + val = max(val, __shfl_down_sync(groupMask, val, offset * shfl_down_stride)); + } + return val; +} + +template +__device__ __forceinline__ ScaleType ComputeDecodeScaleFP4(const float amax, + const float global_encode_scale) { + float decode_scale = amax / TypeExtrema::max; + decode_scale = decode_scale * global_encode_scale; + decode_scale = fminf(decode_scale, TypeExtrema::max); + return static_cast(decode_scale); +} + +template +__device__ __forceinline__ float ComputeEncodeScaleFP4(ScaleType decode_scale, + const float global_decode_scale) { + return fminf(1.0f / (static_cast(decode_scale) * global_decode_scale), + TypeExtrema::max); +} + +template +__device__ __forceinline__ float ComputeOutputFP4(IType input, float encode_scale) { + return static_cast(input) * encode_scale; +} + +__device__ __forceinline__ float ComputeGlobalEncodeScaleFP4(const float global_amax) { + constexpr float fp8_max = TypeExtrema::max; + constexpr float fp4_max = TypeExtrema::max; + float global_encode_scale = fp8_max * fp4_max / global_amax; + // If scale is infinity, return max value of float32 + global_encode_scale = fminf(global_encode_scale, TypeExtrema::max); + // If global amax is 0 or infinity, return 1 + if (global_amax == 0.f || global_encode_scale == 0.f) { + return 1.f; + } + return global_encode_scale; +} + +__device__ __forceinline__ uint32_t +get_rbits(transformer_engine::curanddx::detail::philox4x32_native_state<10>& + rng, // philox4x32_native_state<10>: 10 rounds of philox4_32 + uint4& random_uint4, int& rnd_idx) { + if (rnd_idx == 4) { + rnd_idx = 0; + random_uint4 = rng.generate4(); + } + + // Treat uint4 as an array of 4x uint32_t elements for indexing + const uint32_t* const rbits_arr = reinterpret_cast(&random_uint4); + const uint32_t rbits = 
rbits_arr[rnd_idx++]; + return rbits; +} + +template +__device__ __forceinline__ size_t scale_factor_swizzled_offset(size_t row_idx, size_t col_idx, + uint32_t col_length) { + // This function takes in indices from the scale factor matrix and returns an offset in the + // swizzled format. row_idx, col_idx are original indices from the scale factor matrix (unswizzled + // index). col_length is the column length of the scale factor matrix. tile_scales_inv is the + // pointer to the scale factor matrix. + + // https://github.com/NVIDIA/cutlass/blob/main/media/docs/cpp/blackwell_functionality.md#scale-factor-layouts + // For any scale factor matrix, it's 512B base block. Each base block consists of 128 rows and 4 + // columns. Base block is divided into 4 column blocks, each column block has 32 rows and 4 + // columns. + + // NOTE: There are not a lot of good illustrations about the swizzled scale factor matrix. + // To think in high level, the swizzled scale factor matrix could be composed as: + // unswizzled_scale_factor_matrix = torch.empty((M, N // 16), dtype=torch.uint8) + // cbg_cnt = N // 16 // 4 # Assuming N is divisible by 64 + // rb_cnt = M // 128 # Assuming M is divisible by 128 + // tmp = unswizzled_scale_factor_matrix.reshape(rb_cnt, 4, 32, cbg_cnt, 4) + // tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + // swizzled_scale_factor_matrix = tmp.reshape((-1, 128, 4)) + + constexpr uint32_t kTotalRowsPerBaseBlock = 128; + constexpr uint32_t kRowsPerBaseBlockCol = 32; + constexpr uint32_t kColsPerBaseBlockCol = 4; + + const size_t rb = row_idx / kTotalRowsPerBaseBlock; + const size_t rem = row_idx % kTotalRowsPerBaseBlock; + const size_t d4 = rem / kRowsPerBaseBlockCol; + const size_t d3 = rem % kRowsPerBaseBlockCol; + const size_t cbg = col_idx / kColsPerBaseBlockCol; + const size_t d5 = col_idx % kColsPerBaseBlockCol; + + const size_t cbg_cnt = DIVUP(col_length, kColsPerBaseBlockCol); + // row-major offset in the logical shape + // (rb_cnt , cbg_cnt , 32 , 4 , 4) 
+ // Magic number 16 below comes from the fact we have kColsPerBaseBlockCol = 4, and d4 ([0-128] / + // 32 = [0-4]) + return ((rb * cbg_cnt + cbg) * kRowsPerBaseBlockCol + d3) * 16 + d4 * kColsPerBaseBlockCol + d5; +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const uint32_t rbits) { + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + uint16_t out_4x; + asm volatile( + "{\n" + "cvt.rs.satfinite.e2m1x4.f32 %0, {%3, %4, %1, %2}, %5; \n\t" + "}" + : "=h"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x), "r"(rbits)); + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt.rs PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); + } +} + +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const uint32_t rbits) { + constexpr bool has_fp4 = ARCH_BLACKWELL_FAMILY; + if constexpr (has_fp4) { + // NOTE: rbits unused for rn. + uint32_t out_4x; // Only need 16 bit. Using 32 bit container for packing. + asm volatile( + "{\n" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, %1, %2;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, %3, %4;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "f"(in01.y), "f"(in01.x), "f"(in23.y), "f"(in23.x)); + return reinterpret_cast<__nv_fp4x4_e2m1*>(&out_4x)[0]; + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + uint16_t dummy = 0; + return *reinterpret_cast<__nv_fp4x4_e2m1*>(&dummy); + } +} + +template +__device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const uint32_t rbits) { + if constexpr (kApplyStochasticRounding) { + return cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, rbits); + } else { + return cvt_fp32_to_fp4_4x_with_rn(in01, in23, rbits); + } +} + +template +__global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( + const IType* const input, const float* global_amax, OType* const output_c, + OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, + const size_t row_length, const size_t num_rows, const size_t scale_stride_x, + const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, + const size_t kScaleBlockDim, const float epsilon, const size_t* rng_state, + const float* noop_ptr) { + constexpr int kNVecContainer = kNVecOut / kNFP4PerContainer; + using SMemVec = Vec; + using OVec = Vec; + union IVec { + Vec input_type; + Vec smem_type; + }; + + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + + const size_t block_idx_x = blockIdx.x; + const size_t block_idx_y = blockIdx.y; + const size_t rng_sequence = + threadIdx.x + block_idx_x * kThreadsPerBlock + block_idx_y * gridDim.x * kThreadsPerBlock; + const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0; + const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0; + + transformer_engine::curanddx::detail::philox4x32_native_state<10> rng; + rng.init(rng_seed, rng_sequence, rng_offset); + uint4 random_uint4 = kApplyStochasticRounding ? rng.generate4() : uint4{0, 0, 0, 0}; + + int rnd_idx = + 0; // Index of the random number. 
It increments each time when used and resets to 0 if reaches 4x + + extern __shared__ char smem_base[]; + SMemVec* smem = reinterpret_cast(&smem_base[0]); + + // 2D block scaling is not supported for E8 scaling MXFP4 or for colwise only mode. + // Instead of static_assert, return early if these invalid modes are detected. + if constexpr (kIs2DBlockScaling && kIsE8Scaling) { + return; + } + if constexpr (kIs2DBlockScaling && !kReturnIdentity) { + return; + } + // for 128x128 block, 2D block scaling means there will be 8x8 amax values for nvfp4, 4x4 for 2D mxfp4 + // use constexpr to define the size, when not using 2D, use minimal size 1x1 + constexpr int kFP4BlockScalingSize = 16; + constexpr int k2DBlockAmaxDim = kIs2DBlockScaling ? (kTileDim / kFP4BlockScalingSize) : 1; + constexpr int kNumRowsPerWarp = kThreadsPerWarp / kNumThreadsStore; // 4 + constexpr int k2DBlockAmaxReduceDim = + kIs2DBlockScaling ? (kFP4BlockScalingSize / kNumRowsPerWarp) : 1; + __shared__ CType amax_smem_red[k2DBlockAmaxDim][k2DBlockAmaxDim][k2DBlockAmaxReduceDim]; + __shared__ CType amax_smem[k2DBlockAmaxDim][k2DBlockAmaxDim]; + + // Step 1: Load input to shared memory + { + constexpr int r_stride = kThreadsPerBlock / kNumThreadsLoad; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsLoad) * (kNVecIn / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsLoad; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = (c_g < row_length ? 
min(static_cast(kNVecIn), row_length - c_g) + : 0); // For not aligned case + const IType* input_g = &input[r_g * row_length + c_g]; // Input address in global memory +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + IVec input_vec; + // Step 1.1: Load from global memory (input) to registers + if constexpr (kAligned) { + input_vec.input_type.load_from(input_g); + } else { + if (r_g < num_rows) { + input_vec.input_type.load_from_elts(input_g, 0, num_ele); + } else { + input_vec.input_type.clear(); + } + } + // Step 1.2: Write to shared memory +#pragma unroll + for (int i = 0; i < kNVecIn / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem[r * kSMemCol + c] = input_vec.smem_type.data.elt[i]; + } + // Step 1.3: Update input address, row index of shared memory, (and row index of global memory + // for not aligned case) + input_g += stride_g; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + __syncthreads(); + + const int kNumThreadsReduce = kScaleBlockDim / kNVecOut; + const float global_encode_scale = + kIsE8Scaling ? 1.0f : ComputeGlobalEncodeScaleFP4(global_amax[0]); + const float global_decode_scale = 1.0 / global_encode_scale; + + // Step 2: Cast and store to output_c + if constexpr (kReturnIdentity) { + constexpr int r_stride = + kThreadsPerBlock / kNumThreadsStore; // stride in rows of shared memory + constexpr int num_iterations = kTileDim / r_stride; + const int c_s = + (threadIdx.x % kNumThreadsStore) * (kNVecOut / kNVecSMem); // Column in shared memory + int r_s = threadIdx.x / kNumThreadsStore; // Row in shared memory + const size_t c_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Column in global memory + size_t r_g = block_idx_y * kTileDim + r_s; // Row in global memory + const size_t stride_g = static_cast(r_stride) * row_length; // Stride in global memory + const size_t num_ele = + (c_g < row_length ? 
min(static_cast(kNVecOut / kNFP4PerContainer), + (row_length - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_c[(r_g * row_length + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut / kNVecSMem]; + // Step 2.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { + int c = c_s + i; + int r = r_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } + // Step 2.2: Compute local amax + CType amax = 0; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; ++i) { +#pragma unroll + for (int j = 0; j < kNVecSMem; ++j) { + __builtin_assume(amax >= 0); + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[j])); + } + } + // Step 2.3: Reduce amax + if constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // doing shuffle sync for 2D block scaling (not applicable for E8 scaling) + if constexpr (kIs2DBlockScaling) { + // first amax shuffle sync in warp, then reduce in smem + // T0 T8 T16 T24 should do amax reduction together + constexpr int kNumRowsPerIter = kThreadsPerBlock / kNumThreadsStore; // 32 + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + int 
tid_in_warp_x = threadIdx.x % kNumThreadsStore; + int tid_in_warp_y = (threadIdx.x / kNumThreadsStore) % kNumRowsPerWarp; + CType amax_warp_reduced = groupMax( + amax, WARP_REDUCE_AMAX_GROUP_MASKS[tid_in_warp_x]); + // now T0 ~ T8 in each warp has the reduced amax values + int data_row_idx = iter * kNumRowsPerIter + warp_idx * kNumRowsPerWarp + tid_in_warp_y; + if (tid_in_warp_y == 0) { + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] + [warp_idx % k2DBlockAmaxReduceDim] = amax_warp_reduced; + } + __syncthreads(); + + if (data_row_idx % kFP4BlockScalingSize == 0) { + CType amax_2d = 0.0; + for (int i = 0; i < k2DBlockAmaxReduceDim; i++) { + amax_2d = fmaxf(amax_2d, + amax_smem_red[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x][i]); + } + amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x] = amax_2d; + } + __syncthreads(); + // every thread now knows 2D amax + amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; + } + // Step 2.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 2.5: Write scale_inv + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g < num_rows); + write_scale_inv &= (c_g < row_length); + } + if (write_scale_inv) { + size_t row_idx = block_idx_y * kTileDim + r_s; + size_t col_idx = block_idx_x * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce; + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(row_length, kScaleBlockDim)); + tile_scales_inv_c[offset] = scale_inv; + } else { + tile_scales_inv_c[row_idx * scale_stride_y + col_idx * scale_stride_x] = scale_inv; + } + } + // Step 2.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNVecSMem; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 
f2_a; + float2 f2_b; + f2_a.x = ComputeOutputFP4(smem_vec[i].data.elt[0], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[i].data.elt[1], encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[i + 1].data.elt[0], encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[i + 1].data.elt[1], encode_scale); + const uint32_t rbits = kApplyStochasticRounding ? get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 2.7: Store output_c + if constexpr (kAligned) { + output_vec.store_to(output_g); + } else { + if (r_g < num_rows) { + output_vec.store_to_elts(output_g, 0, num_ele); + } + } + // Step 2.8: Update output address, row index of shared memory (and row index of global memory + // for not aligned case) + output_g += stride_g / kNFP4PerContainer; + r_s += r_stride; + if constexpr (!kAligned) { + r_g += r_stride; + } + } + } + + // Step 3: Transpose, cast and store to output_t + if constexpr (kReturnTranspose) { + constexpr int c_stride = + kThreadsPerBlock / kNumThreadsStore; // Stride in columns of shared memory + constexpr int num_iterations = kTileDim / (c_stride * kNVecSMem); + const int r_s = (threadIdx.x % kNumThreadsStore) * kNVecOut; // Row in shared memory + int c_s = threadIdx.x / kNumThreadsStore; // Column in shared memory + size_t r_g = block_idx_x * kTileDim + c_s * kNVecSMem; // Row in global memory + const size_t c_g = block_idx_y * kTileDim + r_s; // Column in global memory + const size_t stride_g = + static_cast(c_stride) * kNVecSMem * num_rows; // Stride in global memory + const size_t num_ele = (c_g < num_rows ? 
min(static_cast(kNVecOut / kNFP4PerContainer), + (num_rows - c_g) / kNFP4PerContainer) + : 0); // For not aligned case + OType* output_g = + &output_t[(r_g * num_rows + c_g) / kNFP4PerContainer]; // Output address in global memory + // Each kNumThreadsStore threads form a warp process one row, we need to find the lane id of + // the first thread to do the reduction. + const unsigned src_lane = + (threadIdx.x % kThreadsPerWarp) / kNumThreadsReduce * kNumThreadsReduce; + // This mask represents which threads should do the reduction together. + const unsigned mask = ((1 << kNumThreadsReduce) - 1) << src_lane; + const bool is_src_lane = (threadIdx.x % kNumThreadsReduce) == 0; +#pragma unroll + for (int iter = 0; iter < num_iterations; ++iter) { + SMemVec smem_vec[kNVecOut]; + // Step 3.1: Load from shared memory to registers +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + int r = r_s + i; + int c = c_s; + smem_vec[i] = smem[r * kSMemCol + c]; + } +#pragma unroll + for (int smem_idx = 0; smem_idx < kNVecSMem; ++smem_idx) { + // Step 3.2: Compute local amax + CType amax = 0; + if constexpr (kIs2DBlockScaling) { + // TODO(zhongbo): 2D block scaling, directly read from amax_smem + int warp_idx = threadIdx.x / kThreadsPerWarp; // 0 ~ 7 + constexpr int kNumColsPerWarp = + kThreadsPerWarp / kNumThreadsStore * kNVecSMem; // 8 elements + constexpr int kNumWarpsPerBlock = + kThreadsPerBlock / kThreadsPerWarp; // 8 warps per block + constexpr int kNumColsPerIter = kNumColsPerWarp * kNumWarpsPerBlock; + int tid_in_warp_x = (threadIdx.x / kNumThreadsStore) % kNumColsPerWarp; + int tid_in_warp_y = (threadIdx.x % kThreadsPerWarp) % kNumThreadsStore; + int data_col_idx = iter * kNumColsPerIter + warp_idx * kNumColsPerWarp + tid_in_warp_x; + amax = amax_smem[tid_in_warp_y][data_col_idx / kFP4BlockScalingSize]; + } else { +#pragma unroll + for (int i = 0; i < kNVecOut; ++i) { + amax = fmaxf(amax, fabsf(smem_vec[i].data.elt[smem_idx])); + } + } + // Step 3.3: Reduce amax + if 
constexpr (kIsE8Scaling) { +#pragma unroll + for (int delta = kNumThreadsReduce / 2; delta > 0; delta /= 2) { + const float other_amax = __shfl_down_sync(mask, amax, delta); + __builtin_assume(amax >= 0); + __builtin_assume(other_amax >= 0); + amax = fmaxf(amax, other_amax); + } + amax = __shfl_sync(mask, amax, src_lane); + } + // Step 3.4: Compute scale + ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + // Step 3.5: Write scale_inv_t + bool write_scale_inv = is_src_lane; + if constexpr (!kAligned) { + write_scale_inv &= (r_g + smem_idx < row_length); + write_scale_inv &= (c_g < num_rows); + } + if (write_scale_inv) { + size_t row_idx = block_idx_x * kTileDim + c_s * kNVecSMem + smem_idx; + size_t col_idx = (block_idx_y * (kNumThreadsStore / kNumThreadsReduce) + + (threadIdx.x % kNumThreadsStore) / kNumThreadsReduce); + if constexpr (kSwizzledScale) { + size_t offset = scale_factor_swizzled_offset( + row_idx, col_idx, DIVUP(num_rows, kScaleBlockDim)); + tile_scales_inv_t[offset] = scale_inv; + } else { + tile_scales_inv_t[row_idx * scale_t_stride_y + col_idx * scale_t_stride_x] = scale_inv; + } + } + // Step 3.6: Quantize + OVec output_vec; +#pragma unroll + for (int i = 0; i < kNVecOut / kNFP4PerContainer; i += 2) { + // Pack two elements into __nv_bfloat162 + float2 f2_a; + float2 f2_b; + f2_a.x = + ComputeOutputFP4(smem_vec[2 * i].data.elt[smem_idx], encode_scale); + f2_a.y = ComputeOutputFP4(smem_vec[2 * i + 1].data.elt[smem_idx], + encode_scale); + f2_b.x = ComputeOutputFP4(smem_vec[2 * (i + 1)].data.elt[smem_idx], + encode_scale); + f2_b.y = ComputeOutputFP4(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx], + encode_scale); + const uint32_t rbits = + kApplyStochasticRounding ? 
get_rbits(rng, random_uint4, rnd_idx) : 0; + // Convert to __nv_fp4x4_e2m1 + __nv_fp4x4_e2m1 out_4x = cvt_fp32_to_fp4_4x(f2_a, f2_b, rbits); + + output_vec.data.elt[i] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[0]; + output_vec.data.elt[i + 1] = reinterpret_cast<__nv_fp4x2_storage_t*>(&out_4x)[1]; + } + // Step 3.7: Store output_t + if constexpr (kAligned) { + output_vec.store_to(output_g + smem_idx * num_rows / kNFP4PerContainer); + } else { + if (r_g + smem_idx < row_length) { + output_vec.store_to_elts(output_g + smem_idx * num_rows / kNFP4PerContainer, 0, + num_ele); + } + } + } + // Step 3.8: Update output address, column index of shared memory (and row index of global + // memory for not aligned case) + output_g += stride_g / kNFP4PerContainer; + c_s += c_stride; + if constexpr (!kAligned) { + r_g += c_stride * kNVecSMem; + } + } + } +} + +} // namespace +} // namespace quantize_transpose_nvfp4 +#endif // CUDA_VERSION >= 12080 + +namespace detail { + +void quantize_transpose_vector_blockwise_fp4( + const SimpleTensor& input, const SimpleTensor& global_amax, SimpleTensor& scale_inv, + SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, + const bool return_identity, const bool return_transpose, const bool pow2_scale, + const bool swizzled_scale, const bool use_stochastic_rounding, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const SimpleTensor& noop_tensor, cudaStream_t stream) { + NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); +#if CUDA_VERSION >= 12080 + + // pow 2 scale is for MXFP4 since it's using E8M0 scaling + // raise error if pow2_scale is true + NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now"); + + NVTE_CHECK(return_identity || return_transpose, + "At least one of return_identity or return_transpose must be true."); + + NVTE_CHECK(return_identity || !use_2d_quantization, + "2D block quantization is only supported when return_identity is true."); 
+ + const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; + size_t num_elements = row_length; + size_t num_rows = 1; + for (size_t i = 0; (i < input.shape.size() - 1) && (input.shape.size() > 0); ++i) { + num_rows *= input.shape.at(i); + num_elements *= input.shape.at(i); + } + + // Early return if the input tensor is empty + if (num_elements == 0) { + return; + } + + size_t scale_stride_x = 0; + size_t scale_stride_y = 0; + + if (return_identity) { + scale_stride_x = 1; + scale_stride_y = scale_inv.shape[1]; + } + + size_t scale_t_stride_x = 0; + size_t scale_t_stride_y = 0; + + if (return_transpose) { + scale_t_stride_x = 1; + scale_t_stride_y = scale_inv_t.shape[1]; + } + + using namespace transformer_engine::quantize_transpose_nvfp4; + + const size_t num_blocks_x = DIVUP(row_length, static_cast(kTileDim)); + const size_t num_blocks_y = DIVUP(num_rows, static_cast(kTileDim)); + + // noop tensor for cuda graph + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + + const size_t* rng_state = nullptr; + if (rng_state_tensor != nullptr) { + Tensor& rng_state_te_tensor = *convertNVTETensor(rng_state_tensor); + NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64, + "RNG state should contain 2 64-bit values."); + NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector{2}, + "Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape); + rng_state = reinterpret_cast(rng_state_te_tensor.data.dptr); + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype, InputType, + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY( + return_identity ? 
output.dtype : output_t.dtype, 2, OutputType, + + dim3 grid(num_blocks_x, num_blocks_y, 1); + + using ScaleType = fp8e4m3; constexpr int kScaleBlockDim = 16; + constexpr bool kPow2Scale = false; + + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_identity, kReturnIdentity, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + return_transpose, kReturnTranspose, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + full_tile, kAligned, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + swizzled_scale, kSwizzledScale, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_stochastic_rounding, kApplyStochasticRounding, + + TRANSFORMER_ENGINE_SWITCH_CONDITION( + use_2d_quantization, kIs2DBlockScaling, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, + num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, + scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, + noop_ptr);) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // CUDA_VERSION >= 12080 +} + +} // namespace 
detail +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/tensor/_internal/__init__.py b/transformer_engine/common/triton/__init__.py similarity index 69% rename from transformer_engine/pytorch/tensor/_internal/__init__.py rename to transformer_engine/common/triton/__init__.py index e13014bf7..76c9b98d0 100644 --- a/transformer_engine/pytorch/tensor/_internal/__init__.py +++ b/transformer_engine/common/triton/__init__.py @@ -1,4 +1,5 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""Internal data structures for quantized tensors.""" + +"""Kernels written with OpenAI Triton.""" diff --git a/transformer_engine/common/triton/cross_entropy.py b/transformer_engine/common/triton/cross_entropy.py new file mode 100644 index 000000000..fc49ac20b --- /dev/null +++ b/transformer_engine/common/triton/cross_entropy.py @@ -0,0 +1,252 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient Cross Entropy kernels written with OpenAI Triton.""" + +import triton +import triton.language as tl + + +@triton.jit +def online_softmax_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes the m/d components on this TP rank for the online softmax. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride (int): The stride of the m/d/X_y tensor. + rank (int): The rank of this device in the TP group. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) + else: + X_y = float("-inf") + else: + X_y = float("-inf") + + m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 + + # 3. [Online softmax] first pass: find max + sum + m = float("-inf") # m is the max value. use the notation from the paper + d = 0.0 # d is the sum. use the notation from the paper + + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( + tl.float32 + ) + block_max = tl.max(X_block) + m_new = tl.maximum(m, block_max) + d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) + m = m_new + + tl.store(m_d_X_y_ptr, m) + tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) + tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) + + +@triton.jit +def cross_entropy_kernel( + X_ptr, + X_stride, + Y_ptr, + Y_stride, + loss_ptr, + loss_stride, + m_d_X_y_ptr, + m_d_X_y_stride, + rank, + world_size, + ignore_idx, + n_cols, + n_non_ignore, + reduce_loss: tl.constexpr, + label_smoothing: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """ + This kernel computes both cross entropy loss and the gradient of the input. + + Parameters: + X_ptr: Pointer to input tensor. + X_stride (int): The stride of the input tensor. + Y_ptr: Pointer to target tensor. + Y_stride (int): The stride of the target tensor. + loss_ptr: Pointer to tensor to store the loss. + loss_stride (int): The stride of the loss tensor. + m_d_X_y_ptr: Pointer to m/d/X_y tensor. + m_d_X_y_stride: The stride of m/d/X_y tensor. + rank (int): The rank of this device in the TP group. 
+ world_size (int): The size of world involved in this distributed loss calculation. + ignore_idx (int): Tokens to be ignored for loss and gradient calculation. + n_cols (int): The number of columns in the input tensor. + n_non_ignore (int): The number of non-ignored elements in the batch. + label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. + BLOCK_SIZE (int): The block size for Triton operations. + """ + + program_id = tl.program_id(0).to(tl.int64) + + # locate the start index + X_ptr += program_id * X_stride + + # Load Y_ptr + Y_ptr += program_id * Y_stride + y = tl.load(Y_ptr) + + if y == ignore_idx: + # set all X_ptr as 0 + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) + return + + loss_ptr += program_id * loss_stride + m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride + + # Need to reduce the m/d/X_y values from other TP ranks + m = tl.load(m_d_X_y_ptr) + d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) + ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) + + for i in range(1, world_size): + offset = i * 3 * n_non_ignore * m_d_X_y_stride + access_ptr = m_d_X_y_ptr + offset + m_new = tl.load(access_ptr) + d_new = tl.load(access_ptr + m_d_X_y_stride) + X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) + + d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) + m = tl.maximum(m, m_new) + ori_X_y = tl.maximum(ori_X_y, X_y_new) + + # Label smoothing is a general case of normal cross entropy + scaled_x_sum = 0.0 + eps = label_smoothing / (n_cols * world_size) + + # 4. 
[Online softmax] second pass: calculate the gradients + # dx_y = (softmax(x_y) - 1) / N + # dx_i = softmax(x_i) / N, i != y + # N is the number of non ignored elements in the batch + # For label smoothing: + # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y + # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N + # = dx_i - (1 - label_smoothing) / N + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) + grad_dtype = X_block.dtype + X_block = X_block.to(tl.float32) + if label_smoothing > 0: + # scale X beforehand to avoid overflow + scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) + # Scale gradients based on reduction mode + # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore + # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here + if reduce_loss: + X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) + else: + X_block = tl.exp(X_block - m) / d - eps + tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) + + # We need tl.debug_barrier() to ensure the new result of X_ptr is written + tl.debug_barrier() + + # 5. 
Calculate the loss + + # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) + # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) + loss = -(ori_X_y - m - tl.log(d)) + + # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps + # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) + # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) + # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: + # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) + # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 + if label_smoothing > 0: + smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) + loss = loss * (1 - label_smoothing) + smooth_loss + + # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` + vocab_start_idx = rank * n_cols + vocab_end_idx = (rank + 1) * n_cols + if y >= vocab_start_idx: + if y < vocab_end_idx: + X_y = tl.load(X_ptr + y - vocab_start_idx) + # Apply the same conditional scaling logic for the target token + if reduce_loss: + X_y += -(1 - label_smoothing) / (n_non_ignore) + else: + X_y += -(1 - label_smoothing) + tl.store(X_ptr + y - vocab_start_idx, X_y) + + tl.store(loss_ptr, loss) + + +@triton.jit +def element_mul_kernel( + X_ptr, + X_stride, + grad_output_ptr, + grad_output_stride, + n_cols, + BLOCK_SIZE: tl.constexpr, +): + """ + This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. + The multiplication is performed in-place on the tensor pointed by X_ptr. + + Parameters: + X_ptr: Pointer to the input tensor. + X_stride (int): The stride of the input tensor. + grad_output_ptr: Pointer to the gradient output value. + n_cols (int): The number of columns in the input tensor. + BLOCK_SIZE (int): The block size for Triton operations. 
+ """ + + # Get the program ID and convert it to int64 to avoid overflow + program_id = tl.program_id(0).to(tl.int64) + + # Locate the start index + X_ptr += program_id * X_stride + + # Load the gradient output value + grad_output_ptr += program_id * grad_output_stride + grad_output = tl.load(grad_output_ptr) + + # Perform the element-wise multiplication + for i in range(0, n_cols, BLOCK_SIZE): + X_offsets = i + tl.arange(0, BLOCK_SIZE) + X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) + tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) diff --git a/transformer_engine/common/triton/pad.py b/transformer_engine/common/triton/pad.py new file mode 100644 index 000000000..8f15e7dcb --- /dev/null +++ b/transformer_engine/common/triton/pad.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient NVFP4 padding kernels written with OpenAI Triton . + +TODO(ksivamani): Documentation + +""" + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, num_stages=2), + triton.Config({"BLOCK_M": 256, "BLOCK_N": 128}, num_warps=8, num_stages=2), + triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=8, num_stages=1), + ], + key=["out_dim0", "out_dim1"], +) +@triton.jit +def zero_pad_kernel( + inp_ptr, + out_ptr, + in_dim0: tl.constexpr, + in_dim1: tl.constexpr, + out_dim0: tl.constexpr, + out_dim1: tl.constexpr, + in_s0, + in_s1, + out_s0, + out_s1, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + # tile over OUTPUT coordinates + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # output rows + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # 
output cols + om = offs_m[:, None] + on = offs_n[None, :] + + # edge masking for output + out_mask = (om < out_dim0) & (on < out_dim1) + + # valid input region is simply top-left (no offsets) + in_mask = (om < in_dim0) & (on < in_dim1) + + # load valid input, else zero (masked load touches memory only where True) + x = tl.load(inp_ptr + om * in_s0 + on * in_s1, mask=in_mask, other=0) + + # store to output (only within bounds of the output tile) + tl.store(out_ptr + om * out_s0 + on * out_s1, x, mask=out_mask) diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py new file mode 100644 index 000000000..3a3a32014 --- /dev/null +++ b/transformer_engine/common/triton/permutation.py @@ -0,0 +1,605 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Efficient Permutation kernels written with OpenAI Triton.""" + +import triton +import triton.language as tl + +from triton.language import core +from triton.language.standard import _log2 +from packaging import version + + +# The following three argsort related kernels are adapted from +# the issue https://github.com/triton-lang/triton/issues/3698 + +get_int_dtype = core.get_int_dtype +if version.parse(triton.__version__) >= version.parse("3.5.0"): + get_int_dtype = triton.constexpr_function(get_int_dtype) + + +@triton.jit +def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)] + y = tl.reshape(x, shape) + z = tl.reshape(indices, shape) + + mask = tl.arange(0, 2)[None, :, None] + + l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to( + x.dtype + ) + r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to( + x.dtype + ) + + l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - 
mask), 1)[:, None, :], shape), x.shape) + r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) + + idtype = get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) + + il_value = l_value.to(idtype, bitcast=True) + ir_value = r_value.to(idtype, bitcast=True) + ix = x.to(idtype, bitcast=True) + + flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix)) + ret = ix ^ flag1 + flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix)) + ind = indices ^ flag2 + + return ret.to(x.dtype, bitcast=True), ind + + +@triton.jit +def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr): + n_outer: tl.constexpr = x.numel >> n_dims + tl.static_assert(stage <= n_dims) + """ + order_type 0 == ascending + order_type 1 == descending + order_type 2 == alternating + """ + if order == 2: + shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 2**stage] + flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) + else: + flip = tl.full(x.shape, value=order, dtype=tl.int32) + for i in tl.static_range(stage): + x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims) + return x, indices + + +@triton.jit +def _argsort(x, indices, n_dims: tl.constexpr): + for i in tl.static_range(1, n_dims + 1): + x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims) + return x, indices + + +@triton.jit +def _row_id_map_pass_1_kernel( + # pointers + routing_map_ptr, + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # strides + stride_routing_map_token, + stride_routing_map_expert, + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + expert_token_mask = tl.load( + routing_map_ptr + pid_m * stride_routing_map_expert + 
offset * stride_routing_map_token, + mask=(offset < num_tokens), + other=0, + ).to(tl.int32) + row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask + tl.store( + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + row_id_within_token_block, + mask=offset < num_tokens, + ) + n_tokens_per_block = tl.sum(expert_token_mask) + tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block) + + +@triton.jit +def _row_id_map_pass_2_kernel( + # pointers + row_id_map_ptr, + workspace_ptr, + # sizes + num_tokens, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + WORKSPACE_LOAD_WIDTH: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_m = tl.program_id(0) + pid_n = tl.program_id(1) + chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n + offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + row_id_within_token_block = tl.load( + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + mask=(offset < num_tokens), + other=0, + ) + + workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) + n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx) + row_id = tl.where( + row_id_within_token_block == 0, + -1, + row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, + ) + tl.store( + row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, + row_id, + mask=(offset < num_tokens), + ) + + +@triton.jit +def _row_id_map_pass_3_kernel( + # pointers + row_id_map_ptr, + # sizes + num_experts: tl.constexpr, + # strides + stride_row_id_map_token, + stride_row_id_map_expert, + # metas + LOAD_SIZE: tl.constexpr, +): + pid = tl.program_id(0) + n_dims: tl.constexpr = _log2(LOAD_SIZE) + off = tl.arange(0, LOAD_SIZE) + row_id_map = tl.load( + row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off, + mask=off < num_experts, + other=-1, + ) + n_routed = 
tl.sum(tl.where(row_id_map != -1, 1, 0))
    indices = off
    # Descending argsort pushes all routed entries (row id != -1) to the front;
    # `indices` then carries the expert index each routed entry came from.
    sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims)
    # Compacted row layout per token:
    #   [0, n_routed)                       destination rows
    #   [num_experts, num_experts+n_routed) expert index of each destination row
    #   [num_experts * 2]                   n_routed itself
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert,
        sorted_map,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr
        + pid * stride_row_id_map_token
        + (num_experts + off) * stride_row_id_map_expert,
        indices,
        mask=off < n_routed,
    )
    tl.store(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert,
        n_routed,
    )


@triton.jit
def _permute_kernel(
    # pointers
    input_ptr,  # [num_tokens, hidden_size] source activations
    output_ptr,  # permuted activations, one row per (token, routed expert) pair
    row_id_map_ptr,  # compacted map produced by the three passes above
    probs_ptr,  # optional per-(token, expert) probabilities
    scale_ptr,  # optional per-token quantization scales
    permuted_probs_ptr,
    permuted_scale_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    scale_hidden_dim,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_probs_token,
    stride_probs_expert,
    stride_scale_token,
    stride_scale_hidden,
    stride_permuted_probs_token,
    stride_permuted_scale_token,
    stride_permuted_scale_hidden,
    # metas
    PERMUTE_PROBS: tl.constexpr,
    PERMUTE_SCALE: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatter each source token's hidden chunk to every destination row it is routed
    # to, optionally carrying probs and quantization scales along.
    pid_t = tl.program_id(0)  # token index
    pid_h = tl.program_id(1)  # hidden-dimension block index
    cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = cur_off < hidden_size
    # int64 row index to keep the address arithmetic overflow-free on large tensors.
    src_row = pid_t.to(tl.int64)
    input_off = src_row * stride_input_token + cur_off * stride_input_hidden
    inp = tl.load(input_ptr + input_off, mask=mask)
    if PERMUTE_SCALE:
        mask_scale = cur_off < scale_hidden_dim
        scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden
        scale = tl.load(scale_ptr + scale_off, mask=mask_scale)
    # Number of experts this token is routed to (see pass-3 layout).
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
).to(tl.int64)
        output_off = dst_row * stride_output_token + cur_off * stride_output_hidden
        if PERMUTE_SCALE:
            permuted_scale_off = (
                dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden
            )
            tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale)
        if PERMUTE_PROBS:
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert
            prob = tl.load(probs_ptr + prob_off)
            if pid_h == 0:
                # The permuted prob is a scalar per destination row; only the first
                # hidden-block program writes it.
                permuted_prob_off = dst_row * stride_permuted_probs_token
                tl.store(permuted_probs_ptr + permuted_prob_off, prob)
            if prob == 0.0:
                # for routing_map padding
                # dst_row != -1 and prob == 0.0 means that this slot is padded
                tl.store(output_ptr + output_off, 0.0, mask=mask)
            else:
                tl.store(output_ptr + output_off, inp, mask=mask)
        else:
            tl.store(output_ptr + output_off, inp, mask=mask)


# Autotune the hidden-dimension block size. NOTE(review): the RuntimeError guard
# keeps module import alive when triton.autotune rejects the setup — presumably a
# Triton-version compatibility shim; confirm which versions raise here.
try:
    _permute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_permute_kernel)
except RuntimeError:
    pass


@triton.jit
def _unpermute_kernel(
    # pointers
    input_ptr,  # permuted activations (one row per routed (token, expert) pair)
    output_ptr,  # [num_tokens, hidden_size] merged result
    row_id_map_ptr,  # compacted map from the row-id-map passes
    merging_probs_ptr,  # optional per-(token, expert) merge weights
    permuted_probs_ptr,  # optional per-permuted-row probs to un-permute
    unpermuted_probs_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_input_token,
    stride_input_hidden,
    stride_output_token,
    stride_output_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_permuted_probs_token,
    stride_unpermuted_probs_token,
    stride_unpermuted_probs_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    WITH_MERGING_PROBS: tl.constexpr,
    PERMUTE_PROBS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Gather: sum every permuted row that originated from token `pid_t` back into a
    # single output row, optionally weighting each contribution by its merging prob.
    # Accumulation is in fp32 regardless of the storage dtype.
    data_type = input_ptr.dtype.element_ty
    compute_type = tl.float32

    pid_t = tl.program_id(0)  # destination token index
    pid_h = tl.program_id(1)  # hidden-dimension block index
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    if PERMUTE_PROBS:
        # write 0.0 to probs_grad that are not routed
        if pid_h == 0:
            map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
            unpermuted_prob_off = (
                pid_t * stride_unpermuted_probs_token
                + stride_unpermuted_probs_expert * map_load_off
            )
            tl.store(
                unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts
            )
    accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type)
    # Number of routed copies of this token (see the pass-3 row layout).
    n_routed = tl.load(
        row_id_map_ptr
        + pid_t * stride_row_id_map_token
        + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        src_row = tl.load(
            row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert
        ).to(tl.int64)
        input_off = src_row * stride_input_token + current_offset * stride_input_hidden
        inp = tl.load(input_ptr + input_off, mask=mask)
        inp = inp.to(compute_type)
        if WITH_MERGING_PROBS:
            expert_idx = tl.load(
                row_id_map_ptr
                + pid_t * stride_row_id_map_token
                + (num_experts + idx) * stride_row_id_map_expert
            )
            merging_prob_off = (
                pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert
            )
            merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type)
            inp *= merging_prob
        accumulator += inp
        if PERMUTE_PROBS:
            # Scatter each permuted prob back to its (token, expert) slot; scalar
            # per row, so only the first hidden-block program writes.
            if pid_h == 0:
                expert_idx = tl.load(
                    row_id_map_ptr
                    + pid_t * stride_row_id_map_token
                    + (num_experts + idx) * stride_row_id_map_expert
                )
                unpermuted_prob_off = (
                    pid_t * stride_unpermuted_probs_token
                    + expert_idx * stride_unpermuted_probs_expert
                )
                permuted_prob_off = src_row * stride_permuted_probs_token
                prob = tl.load(permuted_probs_ptr + permuted_prob_off)
                tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob)
    # Cast the fp32 accumulator back to the storage dtype for the single store.
    accumulator = accumulator.to(data_type)
    dst_row = pid_t.to(tl.int64)
    output_off = dst_row * stride_output_token + current_offset * stride_output_hidden
    tl.store(output_ptr + output_off, accumulator, mask=mask)


# Autotune the hidden-dimension block size; RuntimeError is swallowed so module
# import survives when autotune setup fails on a given Triton build.
try:
    _unpermute_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_unpermute_kernel)
except RuntimeError:
    pass


@triton.jit
def _unpermute_bwd_with_merging_probs_kernel(
    # pointers
    fwd_output_grad_ptr,  # upstream gradient w.r.t. the unpermuted output
    fwd_input_grad_ptr,  # out: gradient w.r.t. the permuted input rows
    fwd_input_ptr,  # forward permuted input (needed for the prob gradient)
    merging_probs_ptr,
    merging_probs_grad_ptr,  # out: gradient w.r.t. the merging probs
    row_id_map_ptr,
    # sizes
    num_experts: tl.constexpr,
    hidden_size: tl.constexpr,
    # strides
    stride_row_id_map_token,
    stride_row_id_map_expert,
    stride_fwd_output_grad_token,
    stride_fwd_output_grad_hidden,
    stride_fwd_input_grad_token,
    stride_fwd_input_grad_hidden,
    stride_fwd_input_token,
    stride_fwd_input_hidden,
    stride_merging_probs_token,
    stride_merging_probs_expert,
    stride_merging_probs_grad_token,
    stride_merging_probs_grad_expert,
    # metas
    PROBS_LOAD_WIDTH: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # Backward of the merging-probs unpermute: one program per token; gradients
    # accumulate in fp32 and are cast back to the storage dtypes on store.
    data_type = fwd_output_grad_ptr.dtype.element_ty
    compute_type = tl.float32

    pid = tl.program_id(0)
    map_load_off = tl.arange(0, PROBS_LOAD_WIDTH)
    token_probs_grad_off = (
        pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off
    )
    # Zero the whole prob-grad row first; only routed experts are overwritten below.
    tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts)
    n_routed = tl.load(
        row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert
    )
    for idx in tl.range(n_routed):
        dst_row = tl.load(
            row_id_map_ptr + pid *
stride_row_id_map_token + idx * stride_row_id_map_expert + ).to(tl.int64) + expert_idx = tl.load( + row_id_map_ptr + + pid * stride_row_id_map_token + + (num_experts + idx) * stride_row_id_map_expert + ) + prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) + current_start = 0 + while current_start < hidden_size: + current_offset = current_start + tl.arange(0, BLOCK_SIZE) + mask = current_offset < hidden_size + src_row = pid.to(tl.int64) + input_off = ( + src_row * stride_fwd_output_grad_token + + current_offset * stride_fwd_output_grad_hidden + ) + inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) + inp = inp.to(compute_type) + merging_prob_off = ( + pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert + ) + merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) + output = inp * merging_prob + output = output.to(data_type) + output_off = ( + dst_row * stride_fwd_input_grad_token + + current_offset * stride_fwd_input_grad_hidden + ) + tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) + + fwd_input_off = ( + dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden + ) + fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) + prob_grad_accum += fwd_input.to(compute_type) * inp + current_start += BLOCK_SIZE + probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) + probs_grad_off = ( + pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert + ) + tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) + + +try: + _unpermute_bwd_with_merging_probs_kernel = triton.autotune( + configs=[ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + triton.Config({"BLOCK_SIZE": 4096}), + ], + key=["hidden_size"], + 
)(_unpermute_bwd_with_merging_probs_kernel) +except RuntimeError: + pass + + +@triton.jit +def _make_chunk_sort_map_kernel( + # pointers + split_sizes_ptr, + sorted_indices_ptr, + dst_rows_ptr, + # sizes + num_splits: tl.constexpr, + # metas + IDX_LOAD_WIDTH: tl.constexpr, +): + pid = tl.program_id(0) + + load_split_offset = tl.arange(0, IDX_LOAD_WIDTH) + sorted_indices = tl.load( + sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits + ) + + # get chunk idx of the current token in the input tensor + input_split_sizes = tl.load( + split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 + ).to(tl.int32) + input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) + input_chunk_idx = tl.sum(input_split_sizes_mask) + input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) + in_chunk_offset = pid - input_split_sizes_presum + + # get chunk idx of the current token in the output tensor + output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0) + output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1) + + # make row_id_map + output_split_sizes = tl.load( + split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits + ).to(tl.int32) + output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) + dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + tl.store(dst_rows_ptr + pid, dst_row) + + +@triton.jit +def _sort_chunks_by_map_kernel( + # pointers + input_ptr, + output_ptr, + row_id_map_ptr, + probs_ptr, + permuted_probs_ptr, + # sizes + hidden_size: tl.constexpr, + # strides + stride_input_token, + stride_input_hidden, + stride_output_token, + stride_output_hidden, + stride_probs_token, + stride_permuted_probs_token, + # metas + PERMUTE_PROBS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + FORWARD: tl.constexpr, +): + pid_t = tl.program_id(0) + pid_h = tl.program_id(1) 
    # The same map is used in both directions: FORWARD scatters row pid_t to its
    # mapped row; backward gathers from the mapped row back to pid_t.
    if FORWARD:
        src_row = pid_t.to(tl.int64)
        dst_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
    else:
        src_row = tl.load(row_id_map_ptr + pid_t).to(tl.int64)
        dst_row = pid_t.to(tl.int64)
    current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = current_offset < hidden_size
    input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden
    output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden
    inp = tl.load(input_ptr + input_offsets, mask=mask)
    tl.store(output_ptr + output_offsets, inp, mask=mask)
    if PERMUTE_PROBS:
        # Probs are scalar per token; only the first hidden-block program moves them.
        if pid_h == 0:
            prob_off = src_row * stride_probs_token
            prob = tl.load(probs_ptr + prob_off)
            permuted_prob_off = dst_row * stride_permuted_probs_token
            tl.store(permuted_probs_ptr + permuted_prob_off, prob)


# Autotune the hidden-dimension block size; RuntimeError is swallowed so module
# import survives when autotune setup fails on a given Triton build.
try:
    _sort_chunks_by_map_kernel = triton.autotune(
        configs=[
            triton.Config({"BLOCK_SIZE": 64}),
            triton.Config({"BLOCK_SIZE": 128}),
            triton.Config({"BLOCK_SIZE": 256}),
            triton.Config({"BLOCK_SIZE": 512}),
            triton.Config({"BLOCK_SIZE": 1024}),
            triton.Config({"BLOCK_SIZE": 2048}),
            triton.Config({"BLOCK_SIZE": 4096}),
        ],
        key=["hidden_size"],
    )(_sort_chunks_by_map_kernel)
except RuntimeError:
    pass
diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh
deleted file mode 100644
index b7c4cf837..000000000
--- a/transformer_engine/common/util/cast_kernels.cuh
+++ /dev/null
@@ -1,1546 +0,0 @@
-/*************************************************************************
- * This file was modified for portability to AMDGPU
- * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
- *
- * See LICENSE for license information.
- ************************************************************************/
-
-/*! \file cast_kernels.cuh
- * \brief CUDA kernels to cast to/from FP8/MXFP8.
- */ - -#ifndef TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ -#define TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ - -#include -#ifndef __HIP_PLATFORM_AMD__ -#include -#endif //#ifndef __HIP_PLATFORM_AMD__ -#include -#include - -#include - -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" -#include "math.h" -#include "ptx.cuh" -#include "transformer_engine/transformer_engine.h" -#ifdef __HIP_PLATFORM_AMD__ -#include "rocm_cast_kernels.cuh" -#endif - -namespace transformer_engine { - -#ifndef __HIP_PLATFORM_AMD__ -namespace mxfp8_kernel { - -constexpr size_t SCALE_DIM_Y = 32; -constexpr size_t SCALE_DIM_X = 32; - -constexpr size_t BUFFS_NUM = 2; -constexpr size_t PACK_SIZE = 4; -constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; - -// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 - -// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory -constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 - -template -__global__ void __launch_bounds__(THREADS_PER_CHUNK) - cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output_rowwise, - const __grid_constant__ CUtensorMap tensor_map_output_colwise, - e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise, - const float *noop, float *const dbias_workspace, float *const amax_ptr, - const size_t rows, const size_t cols, const size_t scale_stride_rowwise, - const size_t scale_stride_colwise) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; - constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; - - using IType2 = typename ptx::FPx2; - using OType2 = typename ptx::FPx2; - - if constexpr (NO_ACTIVATIONS) { - if 
(noop != nullptr && noop[0] == 1.0f) { - return; - } - } - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; - constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; - static_assert(BUFF_DIM_Y == 32); - - constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; - static_assert(STAGES >= 1); - - constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - - const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; - const size_t tid_X_rowwise = threadIdx.x % THREADS_X; - const size_t tid_Y_colwise = 0; - const size_t tid_X_colwise = threadIdx.x; - - const size_t thread_offset_Y_rowwise = tid_Y_rowwise; - const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const size_t thread_offset_Y_colwise = tid_Y_colwise; - const size_t thread_offset_X_colwise = tid_X_colwise; - - const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; - - const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - - const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const size_t 
scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - - // helps resolving bank conflicts in shmem - const int thread_lane = threadIdx.x % THREADS_PER_WARP; - const int bank_group = thread_lane / THREADS_PER_BANK; - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); - - extern __shared__ char dynamic_shmem[]; - uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); - // Manually align dynamic SHMEM per TMA requirements using padding - // __align__(128) Does not guarantee the pointer to be aligned! 
- uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned - IType *in_sh = reinterpret_cast(dshmem); - IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); - OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); - OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); - IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - - constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float partial_dbias_colwise = 0.0f; - float thread_dbias_rowwise[SCALE_DIM_X]; - if constexpr (IS_DBIAS) { -#pragma unroll - for (int j = 0; j < SCALE_DIM_X; ++j) { - thread_dbias_rowwise[j] = 0.0f; - } - } - - float block_amax = 0.0f; - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[STAGES]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], - &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, - &mbar[0], is_master_thread); - } - -#pragma unroll - for (int stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_offset_Y = stage * BUFF_DIM_Y; - - if (next_stage < STAGES) { - // Wait for TMA transfer to have finished reading shared memory. - // I.e. 
the buffer is ready to be written to - ptx::cp_async_bulk_wait_group_read<1>(); - - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t next_buff_offset = next_buff * BUFF_DIM; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, - global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, - global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); - } - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[stage], parity); - - float thread_amax = 0.0f; - if constexpr (COLWISE_SCALING) { - const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; - thread_amax = 0.0f; - float in_compute_colwise[BUFF_DIM_Y]; - IType in_colwise_IType[BUFF_DIM_Y]; - - // 1. Read/Compute elements. 
Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType thread_amax_f16 = static_cast(0.0f); -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - in_colwise_IType[i] = in_sh[shmem_offset_colwise]; - thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); - } - thread_amax = static_cast(thread_amax_f16); - } else { -#pragma unroll - for (int i = 0; i < BUFF_DIM_Y; ++i) { - const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - - float elt = static_cast(in_sh[shmem_offset_colwise]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - partial_dbias_colwise += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - // Cache computed activations to avoid computing them again in the 2nd pass along another dimension - if constexpr (IS_CACHED_ACT_OP) { - cached_act_sh[shmem_offset_colwise] = static_cast(elt); - } - - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); - const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_colwise[i] = elt; - } - } - - // 2. 
Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - - const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; - const size_t global_scales_offset_X = scales_offset_X_colwise; - const size_t scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - -// 3. Scale elements -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - float in; - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = static_cast(in_colwise_IType[i]); - } else { - in = in_compute_colwise[i]; - } - const float scaled_out = in * block_scale_inverse; - - const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; - out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); - } - } - - if constexpr (ROWWISE_SCALING) { - const size_t shmem_offset_base_rowwise = - buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; - thread_amax = 0.0f; - float in_compute_rowwise[SCALE_DIM_X]; - Vec in_cached[WAVES]; - - // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY - Vec in_IType[WAVES]; - - // 1. Read/Compute elements. 
Find MXFP8-block AMAX - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - // Load elements - in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); - } - } - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } else if constexpr (IS_CACHED_ACT_OP) { - // ensures that all writes to cache made in the section above are visible to all threads - __syncthreads(); - IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); - const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - - // Load cached elements - in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); - // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) - // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries - if (!out_of_bounds) { - if constexpr (std::is_same_v) { -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); - } - } else { -#pragma unroll - for (int e = 0; e < PACK_SIZE; e += 2) { - const IType2 in_cached_2x = {in_cached[w].data.elt[e], - in_cached[w].data.elt[e + 1]}; - ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); - } - } - } - } - if constexpr (!std::is_same_v) { - thread_amax = - static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); - } - } else { -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - - Vec in; - Vec act_in; - - in.load_from(&in_sh[shmem_offset_rowwise]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[shmem_offset_rowwise]); - } -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - // Compute element - float elt = static_cast(in.data.elt[e]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[e]); - elt *= OP(act_in_elt, {}); - } - - // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again - if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { - thread_dbias_rowwise[j] += elt; - } - // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 - if constexpr (!std::is_same_v) { - elt = static_cast(static_cast(elt)); - } - if constexpr (COMPUTE_ACTIVATIONS) { - const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); - const bool swizzled_col_out_of_bounds = - (block_offset_X + swizzled_thread_idx >= cols); - 
const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); - if (!out_of_bounds) { - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - thread_amax = fmaxf(thread_amax, fabsf(elt)); - } - in_compute_rowwise[j] = elt; - } - } - } - - // 2. Compute E8M0 scaling factor - const e8m0_t biased_exponent = - ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const size_t stage_scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - - const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); - const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - - // 3. Scale elements -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - Vec out; -#pragma unroll - for (int e = 0; e < PACK_SIZE / 2; ++e) { - IType2 in; - OType2 &out_pair = reinterpret_cast(out.data.elt[e]); - if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { - in = in_IType[w].data.elt[e]; - } else if constexpr (IS_CACHED_ACT_OP) { - in.x = in_cached[w].data.elt[2 * e]; - in.y = in_cached[w].data.elt[2 * e + 1]; - } else { - const int j = w * PACK_SIZE + 2 * e; - in.x = in_compute_rowwise[j]; - in.y = in_compute_rowwise[j + 1]; - } - ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); - } - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; - out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); - } - } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - // Wait for shared memory 
writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t global_offset_Y = block_offset_Y + stage_offset_Y; - const size_t global_offset_X = block_offset_X; - const size_t buff_offset = buff * BUFF_DIM; - - if constexpr (ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); - } - if constexpr (COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), global_offset_X, - global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); - } - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - } - } - - parity ^= 1; - - if constexpr (IS_DBIAS) { - float thread_partial_dbias = 0.0f; - if constexpr (COLWISE_SCALING) { - thread_partial_dbias = partial_dbias_colwise; - } else { - // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] - // HEIGHT = THREADS_Y - // WIDTH = THREADS_X * (SCALE_DIM_X + 1) - // Added extra 1-element padding per thread_X to reduce bank conflicts - float *partial_dbias_rowwise = reinterpret_cast(dshmem); - - constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - - const size_t shmem_thread_offset = - tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); -#pragma unroll - for (int w = 0; w < WAVES; ++w) { - const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; -#pragma unroll - for (int e = 0; e < PACK_SIZE; ++e) { - const int j = w * PACK_SIZE + e; - const size_t shmem_elt_idx = swizzled_group_offset + e; - 
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; - } - } - __syncthreads(); -#pragma unroll - for (int i = 0; i < THREADS_Y; ++i) { - // Add extra element offset per MXFP8 scaling block [1x32] - const size_t scaling_block = threadIdx.x / SCALE_DIM_X; - thread_partial_dbias += - partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; - } - } - const size_t dbias_stride = cols; - const size_t dbias_offset_Y = blockIdx.y; - const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; - const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; - const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); - if (!col_out_of_bounds_dbias) { - dbias_workspace[dbias_idx] = thread_partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); - } - - if (is_master_thread && amax_ptr != nullptr) { - atomicMaxFloat(amax_ptr, block_amax); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -} // namespace mxfp8_kernel - -constexpr size_t FP8_CHUNK_DIM_Y = 128; -constexpr size_t FP8_CHUNK_DIM_X = 128; -constexpr size_t FP8_THREADS_PER_CHUNK = 128; -constexpr size_t FP8_BUFFERS_NUM = 2; -constexpr size_t FP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(FP8_PREFETCH_BUFFERS_NUM < FP8_BUFFERS_NUM); - -constexpr size_t FP8_BUFFER_DIM_Y = 16; -constexpr size_t FP8_BUFFER_DIM_X = FP8_CHUNK_DIM_X; // 128 -constexpr size_t FP8_SHMEM_DIM_Y = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_SHMEM_DIM_X = FP8_BUFFER_DIM_X; // 128 - -constexpr size_t FP8_BUFF_STAGES_NUM = FP8_BUFFER_DIM_Y; // 16 -constexpr size_t FP8_ITERATIONS = FP8_CHUNK_DIM_Y / FP8_BUFFER_DIM_Y; // 8 = 128 / 16 -static_assert(FP8_ITERATIONS >= FP8_PREFETCH_BUFFERS_NUM); - -template -__global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) - cast_fp8_2D_kernel(const 
__grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_act_input, - const __grid_constant__ CUtensorMap tensor_map_output, - float *const dbias_workspace, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t rows, - const size_t cols) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - - const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - - const size_t thread_offset_Y = tid_Y; - const size_t thread_offset_X = tid_X; - - const size_t dbias_offset_Y = blockIdx.y + tid_Y; - const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; - const bool col_out_of_bounds = my_column >= cols; - const size_t dbias_stride = cols; - - float partial_dbias = 0.f; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) - OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - - constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. 
-#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[FP8_ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - const size_t chunk_offset_Y = block_offset_Y; - const size_t chunk_offset_X = block_offset_X; - -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const size_t chunk_stage_offset_X = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); - } - } - -#pragma unroll - for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const size_t buff = iter % FP8_BUFFERS_NUM; - const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; - if (next_iter < FP8_ITERATIONS) { - const size_t next_buff = next_iter % FP8_BUFFERS_NUM; - const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], - is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); - } - } - - // Wait for the data to have arrived - 
ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const size_t stage_offset_Y = stage; - const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; - const size_t shmem_offset_x = thread_offset_X; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = row >= rows; - const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; - - float elt = static_cast(in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if constexpr (IS_DACT) { - if (!out_of_bounds) { - partial_dbias += elt; - } - } else { - // If no activation, elt is 0 so we can safely do this - partial_dbias += elt; - } - } - __builtin_assume(amax >= 0); - if (IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } - } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); - } - out_sh[buff][shmem_offset_y][shmem_offset_x] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const size_t chunk_it_offset_x = chunk_offset_X; - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. 
- ptx::cp_async_bulk_wait_group_read(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - parity ^= 1; - - if constexpr (IS_DBIAS) { - const size_t dbias_offset_X = my_column; - const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias; - } - } - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -constexpr size_t CHUNKS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = FP8_THREADS_PER_CHUNK; -constexpr size_t CHUNK_SIZE = THREADS_PER_BLOCK; -constexpr size_t ELEMS_PER_BLOCK = CHUNKS_PER_BLOCK * CHUNK_SIZE; -constexpr size_t CHUNKS_PER_ITERATION = 32; -constexpr size_t SHMEM_DIM = CHUNKS_PER_ITERATION * CHUNK_SIZE; -constexpr size_t ITERATIONS = CHUNKS_PER_BLOCK / CHUNKS_PER_ITERATION; -constexpr size_t SHMEM_BUFFERS = 2; -static_assert(CHUNKS_PER_BLOCK % CHUNKS_PER_ITERATION == 0); - -template -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - cast_fp8_1D_kernel(const IType *input_ptr, OType *output_ptr, float *const amax_ptr, - float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - - const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; - const IType *input = input_ptr + block_offset; - OType *output = output_ptr + block_offset; - - float amax = 0; - const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; - - // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - - constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; - - const bool is_master_thread = (threadIdx.x == 0); - -// Initialize shared memory barrier with the number of threads participating in the barrier. -#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - initialize_barriers(mbar, is_master_thread); - - int parity = 0; - - copy_1d_to_shared(&(in_sh[0]), input, transaction_size_IN, &(mbar[0]), is_master_thread); - -#pragma unroll - for (int iter = 0; iter < ITERATIONS; ++iter) { - const size_t buff = iter % SHMEM_BUFFERS; - const size_t it_offset = iter * SHMEM_DIM; - - const size_t next_iter = iter + 1; - const size_t next_buff = next_iter % SHMEM_BUFFERS; - const size_t next_iter_offset = next_iter * SHMEM_DIM; - - if (next_iter < ITERATIONS) { - copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, - &(mbar[next_iter]), is_master_thread); - } - - ptx::fence_proxy_async_shared_cta(); - - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); - -#pragma unroll - for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; - float elt = static_cast(in_sh[buff][shmem_offset]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - __builtin_assume(amax >= 0); - amax = fmaxf(amax, fabsf(elt)); - out_sh[buff][shmem_offset] = static_cast(elt * scale); - } - - // Wait for shared memory writes to be visible to TMA engine. 
- ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - ptx::cp_async_bulk_tensor_1d_shared_to_global( - reinterpret_cast(output + it_offset), - reinterpret_cast(&out_sh[buff]), transaction_size_OUT); - - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read<1>(); - } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (amax_ptr != nullptr) { - const int warp_id = threadIdx.x / THREADS_PER_WARP; - // Reduce the amax over the block - amax = reduce_max(amax, warp_id); - // Update the global amax - if (is_master_thread) { - atomicMaxFloat(amax_ptr, amax); - } - } - - // Update scale-inverse - if (is_master_thread && blockIdx.x == 0 && (scale_inv_ptr != nullptr)) { - reciprocal(scale_inv_ptr, scale); - } - - destroy_barriers(mbar, is_master_thread); -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} -#endif // #ifndef __HIP_PLATFORM_AMD__ - -constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; -template -__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, - const size_t rows, const size_t cols) { - using ComputeVec = Vec; - using OutputVec = Vec; - - const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; - - if (thread_id * nvec >= cols) { - return; - } - - const float *const thread_in_base = dbias_partial + thread_id * nvec; - OType *const thread_out_base = dbias_output + thread_id * nvec; - - ComputeVec ldg_vec; - ComputeVec acc_vec; - acc_vec.clear(); - for (int i = 0; i < rows; ++i) { - ldg_vec.load_from(thread_in_base + i * cols); -#pragma unroll - for (int e = 0; e < nvec; ++e) { - acc_vec.data.elt[e] += 
ldg_vec.data.elt[e]; - } - } - - OutputVec stg_vec; -#pragma unroll - for (int e = 0; e < nvec; ++e) { - stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); - } - stg_vec.store_to(thread_out_base); -} - -template -void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, - cudaStream_t stream) { - constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 - constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); - - NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); - const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); - - reduce_dbias_kernel - <<>>( - reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -#ifndef __HIP_PLATFORM_AMD__ -template -static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { - const size_t N = product(input.data.shape); - - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - NVTE_CHECK(isFullTile, "Only full tiles are supported."); - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - const size_t chunks = DIVUP(N, CHUNK_SIZE); - const size_t blocks = DIVUP(chunks, CHUNKS_PER_BLOCK); - - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - const float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(THREADS_PER_BLOCK); - const dim3 grid(blocks); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - const IType *input_ptr = reinterpret_cast(input.data.dptr); - OType *output_ptr = reinterpret_cast(output->data.dptr); - - cast_fp8_1D_kernel<<>>( - input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) - ); // 
NOLINT(*) - NVTE_CHECK_CUDA(cudaGetLastError()); -} - -template -void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, Tensor *dbias, - Tensor *workspace, cudaStream_t stream) { - checkCuDriverContext(stream); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); - const size_t blocks_Y = chunks_Y; - const size_t blocks_X = chunks_X; - - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - float *const workspace_ptr = IS_DBIAS ? 
reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); - float *const scale_ptr = reinterpret_cast(output->scale.dptr); - - const dim3 block(FP8_THREADS_PER_CHUNK); - const dim3 grid(blocks_X, blocks_Y); - - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->data.dtype, OType, - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output{}; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.data.dtype)); - } - - create_2D_tensor_map(tensor_map_output, output->data, rows, cols, FP8_SHMEM_DIM_Y, - FP8_SHMEM_DIM_X, cols, 0, typeToNumBits(output->data.dtype)); - - cast_fp8_2D_kernel - <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, - workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols); - NVTE_CHECK_CUDA(cudaGetLastError()); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) -} -#endif // #ifndef __HIP_PLATFORM_AMD__ - -template -void mxfp8_quantize(const Tensor &input, const Tensor *act_input, - const Tensor *noop, // TODO (ksivamani) - Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { -#ifndef __HIP_PLATFORM_AMD__ - using namespace mxfp8_kernel; - checkCuDriverContext(stream); -#endif - - bool use_rowwise_scaling = output->has_data(); - bool use_colwise_scaling = output->has_columnwise_data(); - NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); - 
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - - if (use_rowwise_scaling) { - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, - "Columnwise scaling tensor must be allocated"); - } - CheckNoopTensor(*noop, "cast_noop"); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - -#ifdef __HIP_PLATFORM_AMD__ - constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; - constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; - constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; -#else - constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); - - constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; - constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64; - - constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; - constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; - constexpr size_t BUFF_DIM_Y = THREADS_Y; - constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; -#endif - - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); - const dim3 grid(blocks_X, blocks_Y); - const size_t block_size = THREADS_PER_CHUNK; - - const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1; - const size_t scale_stride_colwise = - use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1; - - e8m0_t *const scales_rowwise_ptr = - use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; - e8m0_t *const scales_colwise_ptr = - use_colwise_scaling ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - -#ifndef __HIP_PLATFORM_AMD__ - ScalingType scaling_type; - if (use_rowwise_scaling && (!use_colwise_scaling)) { - scaling_type = ScalingType::ROWWISE; - } else if ((!use_rowwise_scaling) && use_colwise_scaling) { - scaling_type = ScalingType::COLWISE; - } else if (use_rowwise_scaling && use_colwise_scaling) { - scaling_type = ScalingType::BIDIMENSIONAL; - } -#endif - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); - NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); - - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; - workspace->data.dtype = DType::kFloat32; - return; - } - } - - float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; - float *const amax_ptr = reinterpret_cast(output->amax.dptr); - const float *noop_ptr = reinterpret_cast(noop->data.dptr); - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, -#ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, - TRANSFORMER_ENGINE_SWITCH_CONDITION( - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel - <<>>( - reinterpret_cast(input.data.dptr), - (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->columnwise_data.dptr), - scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - ))); // NOLINT(*) -#else // #ifdef __HIP_PLATFORM_AMD__ - - alignas(64) CUtensorMap tensor_map_input{}; - alignas(64) CUtensorMap tensor_map_act_input{}; - alignas(64) CUtensorMap tensor_map_output_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_colwise{}; - - constexpr size_t input_type_bit_size = TypeInfo::size; - constexpr size_t output_type_bit_size = TypeInfo::size; - - create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, - cols, 0, input_type_bit_size); - - if constexpr (IS_DACT) { - create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, input_type_bit_size); - } - - if (use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y, - BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols, - BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size); - } - - constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; - constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; - constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; - constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; - constexpr size_t buff_size_aligned_in = - DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); - constexpr size_t buff_size_aligned_out = - DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); - - constexpr size_t elt_input_mem = buff_size_aligned_in; - constexpr size_t act_input_mem = (IS_DACT ? 
buff_size_aligned_in : 0); - constexpr size_t in_mem = elt_input_mem + act_input_mem; - - const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); - const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); - const size_t out_mem = out_rowwise_mem + out_colwise_mem; - - const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - - switch (scaling_type) { - case ScalingType::ROWWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::COLWISE: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - case ScalingType::BIDIMENSIONAL: - NVTE_CHECK_CUDA(cudaFuncSetAttribute( - cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); - - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); - NVTE_CHECK_CUDA(cudaGetLastError()); - break; - } -#endif // #ifdef __HIP_PLATFORM_AMD__ - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); 
// NOLINT(*) -} - -namespace detail { - -using Empty = transformer_engine::Empty; - -__device__ inline float identity(float value, const Empty &) { return value; } - -struct DequantizeParam { - const float *scale_inv; -}; - -__device__ inline float dequantize_func(float value, const DequantizeParam ¶m) { - return value * (*(param.scale_inv)); -} - -} // namespace detail - -/* HIPCC has strict rules for __device__ functions usage on host. - It forbids not only calling but also other ODR-use assigning to variables - https://github.com/llvm/llvm-project/issues/105825 - Use templated struct wrapper to work around - */ -template -struct ActivationType -{ - static constexpr auto op = OP; -}; - -template -void CastVectorizedUnaryKernelLauncher(const Tensor &input, const Tensor *noop, Tensor *output, - cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; -#else //#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; -#endif //#ifdef __HIP_PLATFORM_AMD__ - const size_t N = product(input.data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input.data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryKernelLauncher( - reinterpret_cast(input.data.dptr), - reinterpret_cast(noop->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -template -void CastVectorizedUnaryGradKernelLauncher(const Tensor &grad, const Tensor *input, Tensor *output, - cudaStream_t stream) { -#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (ActivationType::op == nullptr) ? ActivationType::op : ActivationType::op; -#else //#ifdef __HIP_PLATFORM_AMD__ - constexpr float (*UnaryOP)(float, const ParamOP &) = (OP == nullptr) ? 
detail::identity : OP; -#endif //#ifdef __HIP_PLATFORM_AMD__ - const size_t N = product(input->data.shape); - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - input->data.dtype, IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( - output->data.dtype, OType, - if (!is_fp8_dtype(output->data.dtype) || is_tensor_scaling(output->scaling_mode)) { - constexpr int nvec = 32 / sizeof(IType); - VectorizedUnaryGradKernelLauncher( - reinterpret_cast(grad.data.dptr), - reinterpret_cast(input->data.dptr), - reinterpret_cast(output->data.dptr), - reinterpret_cast(output->scale.dptr), - reinterpret_cast(output->amax.dptr), - reinterpret_cast(output->scale_inv.dptr), N, {}, stream); - } else { - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - }); // NOLINT(*) - ); // NOLINT(*) -} - -namespace { - -#ifndef __HIP_PLATFORM_AMD__ -static bool is_full_tile_1D_tensor(const Tensor *const t) { - const size_t N = product(t->data.shape); - const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); - return isFullTile; -} - -bool dimensions_supported_by_TMA(const Tensor *const t) { - const size_t cols = t->flat_last_dim(); - constexpr size_t TMA_bytes = 16; - const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); - return cols % alignment_requirement == 0; -} -#endif //#ifndef __HIP_PLATFORM_AMD__ - -} // namespace - -#ifndef __HIP_PLATFORM_AMD__ -// Supported by the Arch >= 10.0 -template -void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DBIAS && !IS_DACT) { - if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 - cast_fp8_1D(input, output, stream); - } else { - // Unaligned - 
CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } - } else if (!IS_DBIAS && IS_DACT) { - if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && - is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { - // Aligned AND FP8 (+dAct) - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } else { - // Unaligned - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - } else { - cast_fp8_2D(input, act_input, output, dbias, workspace, - stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize(input, act_input, noop, output, dbias, - workspace, stream); - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} - -// Supported by the Arch < 10.0 -template -void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const Tensor *noop, - Tensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { - if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { - // zhongboz: should we just ignore IS_ACT here? 
- NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); - } - switch (output->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (!IS_DACT) { - CastVectorizedUnaryKernelLauncher(input, noop, output, stream); - } else { - CastVectorizedUnaryGradKernelLauncher(input, act_input, output, stream); - } - break; - } - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + "."); - } -} -#endif //#ifndef __HIP_PLATFORM_AMD__ - -template -void fp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, Tensor *output, - Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckNoopTensor(*noop, "cast_noop"); - CheckInputTensor(input, "cast_input"); - CheckOutputTensor(*output, "cast_output"); - - if constexpr (IS_DBIAS) { - NVTE_CHECK(dbias != nullptr); - CheckOutputTensor(*dbias, "dbias"); - } - if constexpr (IS_DACT) { - NVTE_CHECK(act_input != nullptr); - CheckInputTensor(*act_input, "activation_input"); - NVTE_CHECK(input.dtype() == act_input->dtype(), "Types of both inputs must match."); - NVTE_CHECK(input.data.shape == act_input->data.shape, "Shapes of both inputs must match."); - } - - NVTE_CHECK(!is_fp8_dtype(input.dtype()), "Input must be in higher precision."); - NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); - -#ifndef __HIP_PLATFORM_AMD__ - // NVIDIA - // Supported by the Arch >= 10.0 - if (is_supported_by_CC_100()) { - fp8_quantize_arch_ge_100(input, act_input, noop, output, - dbias, workspace, stream); - } else { // Supported by the Arch < 10.0 - fp8_quantize_arch_l_100(input, act_input, noop, output, - dbias, workspace, stream); - } -#else - // AMD - fp8_quantize_rocm(input, act_input, noop, output, - dbias, workspace, stream); -#endif //#ifndef __HIP_PLATFORM_AMD__ -} - -namespace detail { - -template -void quantize_helper(const NVTETensor 
input, const NVTETensor grad, NVTETensor output, - NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { - const Tensor *input_tensor; - const Tensor *activation_input_tensor; - if constexpr (IS_DBIAS || IS_DACT) { - // backward - input is incoming gradient - input_tensor = convertNVTETensorCheck(grad); - activation_input_tensor = convertNVTETensor(input); - } else { - // forward = input is activation input - input_tensor = convertNVTETensorCheck(input); - activation_input_tensor = nullptr; - } - auto output_tensor = convertNVTETensorCheck(output); - auto dbias_tensor = convertNVTETensor(dbias); - auto workspace_tensor = convertNVTETensor(workspace); - - const QuantizationConfig *quant_config_cpp = - reinterpret_cast(quant_config); - - // extract noop tensor from quant_config_cpp if it's not null - const NVTETensor noop = quant_config_cpp ? quant_config_cpp->noop_tensor : nullptr; - const auto noop_tensor = noop != nullptr ? *(convertNVTETensorCheck(noop)) : Tensor(); - - switch (output_tensor->scaling_mode) { - case NVTE_DELAYED_TENSOR_SCALING: { - if (output_tensor->has_columnwise_data()) { - NVTE_CHECK(output_tensor->has_data(), - "Quantizing in only the columnwise direction not supported yet!"); - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - cast_transpose(*input_tensor, noop_tensor, output_tensor, stream); - } else { - cast_transpose_fused( - *input_tensor, activation_input_tensor, output_tensor, dbias_tensor, workspace_tensor, - stream); - } - } else if (output_tensor->has_data()) { - fp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - } - break; - } - case NVTE_MXFP8_1D_SCALING: { - mxfp8_quantize( - *input_tensor, activation_input_tensor, &noop_tensor, output_tensor, dbias_tensor, - workspace_tensor, stream); - break; - } -#ifndef __HIP_PLATFORM_AMD__ - case NVTE_BLOCK_SCALING_2D: { - // TODO(kwyss): IS_BIAS, IS_DACT, 
IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_2D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : true; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - quantize_transpose_square_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, - /*noop_tensor=*/noop_tensor.data, stream); - break; - } - case NVTE_BLOCK_SCALING_1D: { - // TODO(kwyss): IS_BIAS, IS_DACT, IS_ACT, ParamOP, OP parameters support. - NVTE_CHECK((!IS_DBIAS && !IS_DACT && !IS_ACT), - "IS_DBIAS, IS_DACT, and IS_ACT not implemented for NVTE_BLOCK_SCALING_1D"); - bool force_pow_2_scales = quant_config_cpp ? quant_config_cpp->force_pow_2_scales : false; - float epsilon = quant_config_cpp ? quant_config_cpp->amax_epsilon : 0.0f; - FP8BlockwiseRowwiseOption rowwise_option = FP8BlockwiseRowwiseOption::NONE; - FP8BlockwiseColumnwiseOption columnwise_option = FP8BlockwiseColumnwiseOption::NONE; - if (output_tensor->has_data()) { - bool rowwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - rowwise_option = rowwise_compact ? FP8BlockwiseRowwiseOption::ROWWISE_COMPACT - : FP8BlockwiseRowwiseOption::ROWWISE_GEMM_READY; - } - if (output_tensor->has_columnwise_data()) { - bool columnwise_compact = quant_config_cpp - ? quant_config_cpp->float8_block_scale_tensor_format == - Float8BlockScaleTensorFormat::COMPACT - : false; - columnwise_option = columnwise_compact - ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT - : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; - } - quantize_transpose_vector_blockwise( - input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, - output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, noop_tensor.data, stream); - break; - } -#endif - default: - NVTE_ERROR("Not implemented scaling mode: " + to_string(output_tensor->scaling_mode) + "."); - } -} - -} // namespace detail -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_CAST_KERNELS_CUH_ diff --git a/transformer_engine/common/util/curanddx.hpp b/transformer_engine/common/util/curanddx.hpp new file mode 100644 index 000000000..4d7c90a01 --- /dev/null +++ b/transformer_engine/common/util/curanddx.hpp @@ -0,0 +1,106 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ + +namespace transformer_engine { +namespace curanddx { +namespace detail { + +inline constexpr unsigned int philox4x32_w32_0 = 0x9E3779B9U; +inline constexpr unsigned int philox4x32_w32_1 = 0xBB67AE85U; +inline constexpr unsigned int philox4x32_m4x32_0 = 0xD2511F53U; +inline constexpr unsigned int philox4x32_m4x32_1 = 0xCD9E8D57U; + +__forceinline__ __device__ unsigned int mulhilo32(unsigned int a, unsigned int b, + unsigned int* hip) { + *hip = __umulhi(a, b); + return a * b; +} + +__forceinline__ __device__ uint4 single_round(uint4 ctr, uint2 key) { + unsigned int hi0; + unsigned int hi1; + unsigned int lo0 = mulhilo32(philox4x32_m4x32_0, ctr.x, &hi0); + unsigned int lo1 = mulhilo32(philox4x32_m4x32_1, ctr.z, &hi1); + + uint4 ret = {hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0}; + return ret; +} + +template +__forceinline__ __device__ uint4 multiple_rounds(uint4 c, uint2 k) { + for (unsigned int i = 0; i < Rounds - 1; i++) { + c = single_round(c, k); // 1 + k.x += philox4x32_w32_0; + k.y += philox4x32_w32_1; + } + return single_round(c, k); // Rounds +} + +template +struct philox4x32_native_state { + static constexpr unsigned int rounds = Rounds; + + uint4 ctr; + uint2 key; + + __forceinline__ __device__ void philox_state_incr() { + if (++ctr.x) return; + if (++ctr.y) return; + if (++ctr.z) return; + ++ctr.w; + } + + __forceinline__ __device__ void philox_state_incr(size_t n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + + ctr.x += nlo; + if (ctr.x < nlo) nhi++; + + ctr.y += nhi; + if (nhi <= ctr.y) return; + if (++ctr.z) return; + ++ctr.w; + } + + __forceinline__ __device__ void philox_state_incr_hi(size_t n) { + unsigned int nlo = (unsigned int)(n); + unsigned int nhi = (unsigned int)(n >> 32); + + ctr.z += nlo; + if (ctr.z < 
nlo) nhi++; + + ctr.w += nhi; + } + + // offset is the total # of 128bits generated with a single generate4() call + __forceinline__ __device__ void skip_offset(size_t n) { philox_state_incr(n); } + + __forceinline__ __device__ void skip_subsequence(size_t n) { philox_state_incr_hi(n); } + + __forceinline__ __device__ void init(size_t seed, size_t subsequence, size_t offset) { + ctr = make_uint4(0, 0, 0, 0); + key.x = (unsigned int)seed; + key.y = (unsigned int)(seed >> 32); + + skip_subsequence(subsequence); + skip_offset(offset); + } + + __forceinline__ __device__ uint4 generate4() { + auto tmp = multiple_rounds(ctr, key); + philox_state_incr(); + return tmp; + } +}; +} // namespace detail +} // namespace curanddx +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_COMMON_UTIL_CURANDDX_HPP_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 6ab5eb958..ebcf99afe 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -18,6 +18,10 @@ #endif // __HIP_PLATFORM_AMD__ #include +#ifndef __HIP_PLATFORM_AMD__ +#include "nccl.h" +#endif //#ifndef __HIP_PLATFORM_AMD__ + #ifdef NVTE_WITH_CUBLASMP #include #endif // NVTE_WITH_CUBLASMP @@ -121,4 +125,13 @@ #endif // NVTE_WITH_CUBLASMP +#ifndef __HIP_PLATFORM_AMD__ +#define NVTE_CHECK_NCCL(expr) \ + do { \ + const ncclResult_t status_NVTE_CHECK_NCCL = (expr); \ + if (status_NVTE_CHECK_NCCL != ncclSuccess) { \ + NVTE_ERROR("NCCL Error: ", ncclGetErrorString(status_NVTE_CHECK_NCCL)); \ + } \ + } while (false) +#endif //#ifndef __HIP_PLATFORM_AMD__ #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ diff --git a/transformer_engine/common/util/math.h b/transformer_engine/common/util/math.h index 2d425d675..005a60067 100644 --- a/transformer_engine/common/util/math.h +++ b/transformer_engine/common/util/math.h @@ -11,6 +11,11 @@ namespace transformer_engine { struct Empty {}; +struct ClampedSwiGLUParam { + float limit; + float alpha = 1.702f; // Default value for QuickGELU +}; + template __device__ inline OType gelu(const IType val, const Empty&) { const float cval = val; @@ -31,6 +36,8 @@ __device__ inline OType sigmoid(const IType val, const Empty&) { return 1.f / (1.f + expf(-cval)); } +__device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf(-x)); } + template __device__ inline OType dsigmoid(const IType val, const Empty& e) { const float cval = val; @@ -38,17 +45,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) { return s * (1.f - s); } +template +__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) { + const float cval = val; + Empty e = {}; + return cval * sigmoid(alpha * cval, e); +} + template __device__ inline OType qgelu(const IType val, const Empty& e) { + return qgelu_with_alpha(val, 1.702f); +} + +template +__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) { const float cval = val; - return cval * sigmoid(1.702f * cval, e); + Empty 
e = {}; + return alpha * cval * dsigmoid(alpha * cval, e) + + sigmoid(alpha * cval, e); } template __device__ inline OType dqgelu(const IType val, const Empty& e) { - const float cval = val; - return 1.702f * cval * dsigmoid(1.702f * cval, e) + - sigmoid(1.702f * cval, e); + return dqgelu_with_alpha(val, 1.702f); } template @@ -57,12 +76,26 @@ __device__ inline OType silu(const IType val, const Empty& e) { return cval * sigmoid(cval, e); } +template +__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) { + const float cval = min(p.limit, static_cast(val)); // Clamping + return qgelu_with_alpha(cval, p.alpha); +} + template __device__ inline OType dsilu(const IType val, const Empty& e) { const float cval = val; return cval * dsigmoid(cval, e) + sigmoid(cval, e); } +template +__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) { + const bool dclamp_val = static_cast(val) <= p.limit; + const float clamp_val = min(static_cast(val), p.limit); + const float dsilu_val = dqgelu_with_alpha(clamp_val, p.alpha); + return dclamp_val ? dsilu_val : 0.0f; +} + template __device__ inline OType relu(IType value, const Empty&) { return fmaxf(value, 0.f); diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 7c38a337b..312890db0 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -16,44 +16,175 @@ #include #include +#if CUDA_VERSION >= 12080 +#include +#endif // CUDA_VERSION >= 12080 + +#include "common/utils.cuh" +#ifdef __HIP_PLATFORM_AMD__ +#include "../util/vectorized_pointwise.h" +#endif //#ifndef __HIP_PLATFORM_AMD__ + namespace transformer_engine { + namespace ptx { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +#ifndef __HIP_PLATFORM_AMD__ +template +struct ArchSpecific { + constexpr static int id = N * 10; + + template + constexpr static bool compatible() { + if constexpr (CurrentArch == id) { + static_assert(ArchSpecific == CurrentArch, + "Compiled for the generic architecture, while utilizing arch-specific " + "features. Please compile for smXXXa architecture instead of smXXX " + "architecture."); + return true; + } else { + return false; + } + } +}; + +template +struct FamilySpecific { + constexpr static int id = N * 10; + + template + constexpr static bool compatible() { + if constexpr ((CurrentArch / 100) == (id / 100)) { + static_assert(FamilySpecific == CurrentArch, + "Compiled for the generic architecture, while utilizing family-specific " + "features. 
Please compile for smXXXf architecture instead of smXXX " + "architecture."); + return true; + } else { + return false; + } + } +}; + +template +constexpr bool is_supported_arch() { + if constexpr (T::template compatible()) { + return true; + } else if constexpr (sizeof...(U) != 0) { + return is_supported_arch(); + } else { + return false; + } +} + +#if CUDA_VERSION < 12090 +#if __CUDA_ARCH_HAS_FEATURE__(SM90_ALL) +#define __CUDA_ARCH_SPECIFIC__ 900 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 900 +#endif +#if __CUDA_ARCH_HAS_FEATURE__(SM100_ALL) +#define __CUDA_ARCH_SPECIFIC__ 1000 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1000 +#endif +#if __CUDA_ARCH_HAS_FEATURE__(SM101_ALL) +#define __CUDA_ARCH_SPECIFIC__ 1010 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1010 +#endif +#if __CUDA_ARCH_HAS_FEATURE__(SM120_ALL) +#define __CUDA_ARCH_SPECIFIC__ 1200 +#define __CUDA_ARCH_FAMILY_SPECIFIC__ 1200 +#endif +#endif + +#ifdef __CUDA_ARCH__ +#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = __CUDA_ARCH__; +#else +#define __NVTE_CURRENT_ARCH__ constexpr int current_arch = 0; +#endif + +#ifdef __CUDA_ARCH_SPECIFIC__ +#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = __CUDA_ARCH_SPECIFIC__; +#else +#define __NVTE_ARCH_SPECIFIC__ constexpr int ArchSpecific = 0; +#endif + +#ifdef __CUDA_ARCH_FAMILY_SPECIFIC__ +#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = __CUDA_ARCH_FAMILY_SPECIFIC__; +#else +#define __NVTE_ARCH_FAMILY_SPECIFIC__ constexpr int FamilySpecific = 0; +#endif + +#define NVTE_CUDA_ARCH_MATCHES(...) 
\ + [&] { \ + __NVTE_CURRENT_ARCH__ \ + __NVTE_ARCH_SPECIFIC__ \ + __NVTE_ARCH_FAMILY_SPECIFIC__ \ + return transformer_engine::ptx::is_supported_arch(); \ + }(); + +#define ARCH_BLACKWELL_FAMILY \ + NVTE_CUDA_ARCH_MATCHES(ptx::FamilySpecific<100>, ptx::FamilySpecific<110>, \ + ptx::FamilySpecific<120>) +#define ARCH_HAS_STOCHASTIC_ROUNDING \ + NVTE_CUDA_ARCH_MATCHES(ptx::ArchSpecific<100>, ptx::ArchSpecific<103>) + +#endif //#ifndef __HIP_PLATFORM_AMD__ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init __device__ __forceinline__ void mbarrier_init(uint64_t *mbar, const uint32_t count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.init.shared.b64 [%0], %1;" ::"r"(mbar_ptr), "r"(count) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_init is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval __device__ __forceinline__ void mbarrier_invalid(uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.inval.shared.b64 [%0];" ::"r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_invalid is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive __device__ __forceinline__ void mbarrier_arrive(uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.shared.b64 _, [%0];" ::"r"(mbar_ptr) : "memory"); +#else + 
NVTE_DEVICE_ERROR("mbarrier_arrive is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-arrive __device__ __forceinline__ void mbarrier_arrive_expect_tx(uint64_t *mbar, const uint32_t tx_count) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" ::"r"(mbar_ptr), "r"(tx_count) : "memory"); +#else + NVTE_DEVICE_ERROR("mbarrier_arrive_expect_tx is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void fence_mbarrier_init_release_cluster() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile("fence.mbarrier_init.release.cluster;"); +#else + NVTE_DEVICE_ERROR("fence_mbarrier_init_release_cluster is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // global -> shared::cluster __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( uint64_t *dst_shmem, const uint64_t *src_global_ptr, const uint32_t size, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); // triggers async copy, i.e. 
the thread continues until wait() on mbarrier @@ -65,6 +196,9 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( ".mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];" ::"r"(dst_shmem_ptr), "l"(src_global_ptr), "r"(size), "r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_global_to_shared is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -72,6 +206,7 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_global_to_shared( __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( uint64_t *dst_shmem, const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, uint64_t *mbar) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t dst_shmem_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); // triggers async copy, i.e. 
the thread continues until wait() on mbarrier @@ -83,9 +218,13 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" ::"r"(dst_shmem_ptr), "l"(tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(mbar_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_global_to_shared is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, const uint32_t parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t waitComplete; asm volatile( "{\n\t .reg .pred P_OUT; \n\t" @@ -96,15 +235,21 @@ __device__ __forceinline__ bool mbarrier_try_wait_parity(uint32_t mbar_ptr, cons : "r"(mbar_ptr), "r"(parity) : "memory"); return static_cast(waitComplete); +#else + NVTE_DEVICE_ERROR("mbarrier_try_wait_parity is only supported on SM 10.0+."); +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + return true; } __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint32_t parity) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); while (!mbarrier_try_wait_parity(mbar_ptr, parity)) { } -} - +#else + NVTE_DEVICE_ERROR("mbarrier_wait_parity is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} constexpr uint32_t FP32_MANTISSA_BITS = 23; constexpr uint32_t FP32_EXPONENT_BIAS = 127; @@ -120,53 +265,57 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { } __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { -#ifdef __HIP_PLATFORM_AMD__ -#define __CUDA_ARCH_HAS_FEATURE__(x) 0 -#endif -#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ - (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) - uint16_t out; - asm volatile( - "{\n" - "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" - "}" - 
: "=h"(out) - : "f"(val)); - return *reinterpret_cast(&out); -#else - // TODO: nan/inf needs to be set for any value - // of nan/inf in input not just amax. - if (isnan(val)) { - return 0xFF; - } - if (isinf(val)) { - return 0xFE; - } - if (val == 0.0f) { - return 0x00; - } - uint32_t val_u32 = *reinterpret_cast(&val); - e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); - uint32_t mantissa = val_u32 & 0x7FFFFF; - // Round up exponent and deal with satfinite. - if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; +#ifndef __HIP_PLATFORM_AMD__ + constexpr bool is_blackwell = false; + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); + } else { +#endif //#ifndef __HIP_PLATFORM_AMD__ + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. + if (isnan(val)) { + return 0xFF; + } + if (isinf(val)) { + return 0xFE; + } + if (val == 0.0f) { + return 0x00; + } + uint32_t val_u32 = *reinterpret_cast(&val); + e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS); + uint32_t mantissa = val_u32 & 0x7FFFFF; + // Round up exponent and deal with satfinite. 
+ if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { + ++exponent; + } + return exponent; +#ifndef __HIP_PLATFORM_AMD__ } - return exponent; -#endif +#endif //#ifndef __HIP_PLATFORM_AMD__ } -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor // shared::cta -> global __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_t *dst_global_ptr, const uint64_t *src_shmem, const uint32_t size) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;" ::"l"(dst_global_ptr), "r"(src_shmem_ptr), "r"(size) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_1d_shared_to_global is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -174,51 +323,93 @@ __device__ __forceinline__ void cp_async_bulk_tensor_1d_shared_to_global(uint64_ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( const uint64_t *tensor_map_ptr, const uint32_t offset_x, const uint32_t offset_y, uint64_t *src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); asm volatile("cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::"l"( tensor_map_ptr), "r"(offset_x), "r"(offset_y), "r"(src_shmem_ptr) : "memory"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group __device__ __forceinline__ void cp_async_bulk_wait_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group template __device__ __forceinline__ void cp_async_bulk_wait_group_read() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 1;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 2;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template <> __device__ __forceinline__ void 
cp_async_bulk_wait_group_read<4>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.wait_group.read 4;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group __device__ __forceinline__ void cp_async_bulk_commit_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.commit_group;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // Proxy fence (bi-directional): -__device__ __forceinline__ void fence_proxy_async() { asm volatile("fence.proxy.async;"); } +__device__ __forceinline__ void fence_proxy_async() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.async;"); +#else + NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} __device__ __forceinline__ void fence_proxy_async_shared_cta() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("fence.proxy.async.shared::cta;"); +#else + NVTE_DEVICE_ERROR("fence_proxy_async_shared_cta is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } template @@ -227,21 +418,264 @@ struct alignas(2 * sizeof(T)) FPx2 { T y; }; +template +struct FPx4 { + T x1; + T x2; + T x3; + T x4; +}; + +template +struct Type2x {}; + +template <> +struct Type2x { + using type = float2; +}; + +template <> +struct Type2x { + using type = __nv_bfloat162; +}; + +template <> +struct Type2x { + using type = __half2; +}; + using floatx2 = FPx2; using bf16x2 = FPx2; using fp16x2 = FPx2; using fp8e4m3x2 = FPx2; using fp8e5m2x2 = FPx2; +using floatx4 = FPx4; +using bf16x4 = FPx4; 
+using fp16x4 = FPx4; +using fp8e4m3x4 = FPx4; +using fp8e5m2x4 = FPx4; + static_assert(sizeof(floatx2) == 8); static_assert(sizeof(bf16x2) == 4); static_assert(sizeof(fp16x2) == 4); static_assert(sizeof(fp8e4m3x2) == 2); static_assert(sizeof(fp8e5m2x2) == 2); +#if FP4_TYPE_SUPPORTED +using fp4e2m1 = __nv_fp4_e2m1; +using fp4e2m1x2 = __nv_fp4x2_e2m1; +using fp4e2m1x4 = __nv_fp4x4_e2m1; +static_assert(sizeof(fp4e2m1x2) == 1); +static_assert(sizeof(fp4e2m1x4) == 2); + +// When converting to .e2m1x2 data formats, the destination operand d has .b8 type. +// When converting two .f32 inputs to .e2m1x2, each input is converted to the specified format, +// and the converted values are packed in the destination operand d such that the value +// converted from input a is stored in the upper 4 bits of d and the value converted +// from input b is stored in the lower 4 bits of d. + +// SIMD like "Fused" cast + multiplication (x4) +template +__device__ __forceinline__ void mul_cvt_4x(fp4e2m1x4 &out, const Tx2 &in01, const Tx2 &in23, + const float scale) { + const float x0 = static_cast(in01.x) * scale; + const float x1 = static_cast(in01.y) * scale; + const float x2 = static_cast(in23.x) * scale; + const float x3 = static_cast(in23.y) * scale; + out = fp4e2m1x4(make_float4(x0, x1, x2, x3)); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding( + const uint64_t in_4x, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + 
"cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %3; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x_with_rn(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. 
+ asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b16 v0_bf16; \n\t" + ".reg.b16 v1_bf16; \n\t" + ".reg.b16 v2_bf16; \n\t" + ".reg.b16 v3_bf16; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0_bf16, v1_bf16, v2_bf16, v3_bf16} , %1; \n\t" + "cvt.f32.bf16 v0, v0_bf16; \n\t" + "cvt.f32.bf16 v1, v1_bf16; \n\t" + "cvt.f32.bf16 v2, v2_bf16; \n\t" + "cvt.f32.bf16 v3, v3_bf16; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %2; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %2; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(in_4x), "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. 
" + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_bf16_to_fp4_4x(const uint64_t in_4x, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_bf16_to_fp4_4x_with_stochastic_rounding(in_4x, scale, rbits); + } else { + return mul_cvt_bf16_to_fp4_4x_with_rn(in_4x, scale, rbits); + } +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding( + const float2 in01, const float2 in23, const float2 scale, const uint32_t rbits) { + uint16_t out_4x = 0; + constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING; + if constexpr (has_rs) { + asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rs.satfinite.e2m1x4.f32 %0, {v2, v3, v0, v1}, %4; \n\t" // mind the shuffled elements order + "}" + : "=h"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale)), "r"(rbits)); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return *reinterpret_cast(&out_4x); +} + +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x_with_rn(const float2 in01, + const float2 in23, + const float2 scale, + const uint32_t rbits) { + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + uint32_t out_4x = 0; // Only need 16 bit. Using 32 bit container for packing. + if constexpr (is_blackwell) { + // NOTE: rbits unused for rn. 
+ asm volatile( + "{\n" + ".reg.b64 v01; \n\t" + ".reg.b64 v23; \n\t" + ".reg.b32 v0; \n\t" + ".reg.b32 v1; \n\t" + ".reg.b32 v2; \n\t" + ".reg.b32 v3; \n\t" + ".reg.b8 f0; \n\t" + ".reg.b8 f1; \n\t" + "mov.b64 {v0, v1} , %1; \n\t" + "mov.b64 {v2, v3} , %2; \n\t" + "mov.b64 v01, {v0, v1}; \n\t" + "mov.b64 v23, {v2, v3}; \n\t" + "mul.f32x2 v01, v01, %3; \n\t" // mind the shuffled elements order + "mul.f32x2 v23, v23, %3; \n\t" // mind the shuffled elements order + "mov.b64 {v1, v0}, v01; \n\t" + "mov.b64 {v3, v2}, v23; \n\t" + "cvt.rn.satfinite.e2m1x2.f32 f0, v0, v1;\n\t" + "cvt.rn.satfinite.e2m1x2.f32 f1, v2, v3;\n\t" + "mov.b32 %0, {f0, f1, f0, f1};\n\t" + "}" + : "=r"(out_4x) + : "l"(reinterpret_cast(in01)), + "l"(reinterpret_cast(in23)), + "l"(reinterpret_cast(scale))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + return reinterpret_cast(&out_4x)[0]; +} + +template +__device__ __forceinline__ fp4e2m1x4 mul_cvt_fp32_to_fp4_4x(const float2 in01, const float2 in23, + const float2 scale, + const uint32_t rbits) { + if constexpr (USE_STOCHASTIC_ROUNDING) { + return mul_cvt_fp32_to_fp4_4x_with_stochastic_rounding(in01, in23, scale, rbits); + } else { + return mul_cvt_fp32_to_fp4_4x_with_rn(in01, in23, scale, rbits); + } +} +#endif // FP4_TYPE_SUPPORTED + // SIMD like "Fused" cast + multiplication (x2) __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair; \n\t" @@ -254,10 +688,14 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in, : "=h"(reinterpret_cast(out)) : "l"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void 
mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair; \n\t" @@ -270,9 +708,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in, : "=h"(reinterpret_cast(out)) : "l"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -292,9 +734,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -314,9 +760,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -336,9 +786,13 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + 
NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) asm volatile( "{\n" ".reg.b64 val_pair_before; \n\t" @@ -358,24 +812,33 @@ __device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, con : "=h"(reinterpret_cast(out)) : "r"(reinterpret_cast(in)), "l"(reinterpret_cast(scale))); +#else + NVTE_DEVICE_ERROR("mul_cvt_2x is only supported on SM 10.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } __device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;" : "=r"(reinterpret_cast(dst)) : "r"(reinterpret_cast(p1)), "r"(reinterpret_cast(p2))); +#else + NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } __device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;" : "=r"(reinterpret_cast(dst)) : "r"(reinterpret_cast(p1)), "r"(reinterpret_cast(p2))); +#else + NVTE_DEVICE_ERROR("abs_max_2x is only supported on SM 8.9+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 890) } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - } // namespace ptx namespace { @@ -393,6 +856,8 @@ __forceinline__ __device__ void initialize_barriers(uint64_t *mbar, const bool i } // Syncthreads so initialized barrier is visible to all threads. 
__syncthreads(); +#else + NVTE_DEVICE_ERROR("initialize_barriers is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -408,6 +873,8 @@ __forceinline__ __device__ void destroy_barriers(uint64_t *mbar, const bool is_m ptx::mbarrier_invalid(&mbar[iter]); } } +#else + NVTE_DEVICE_ERROR("destroy_barriers is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -427,9 +894,12 @@ __forceinline__ __device__ void copy_1d_to_shared(void *dst, const void *src, // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_1d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } + __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, const size_t chunk_X, const size_t chunk_Y, const size_t num_bytes, uint64_t *barrier, const bool is_master_thread) { @@ -446,6 +916,8 @@ __forceinline__ __device__ void copy_2d_to_shared(void *dst, const void *src, co // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_2d_to_shared is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -472,6 +944,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx2(void *dst, const void *src, // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_2d_to_sharedx2 is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } @@ -501,6 +975,8 @@ __forceinline__ __device__ void copy_2d_to_sharedx3( // Other threads just arrive ptx::mbarrier_arrive(barrier); } +#else + NVTE_DEVICE_ERROR("copy_2d_to_sharedx3 is only supported on SM 10.0+."); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index b243a8a0b..301748d06 
100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -108,7 +108,8 @@ .value("kFloat16", transformer_engine::DType::kFloat16) \ .value("kBFloat16", transformer_engine::DType::kBFloat16) \ .value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \ - .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2); \ + .value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \ + .value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \ pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) \ .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \ .value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \ @@ -122,6 +123,10 @@ .value("NVTE_CAUSAL_BOTTOM_RIGHT_MASK", NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK) \ .value("NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK", \ NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK); \ + pybind11::enum_(m, "NVTE_Softmax_Type", pybind11::module_local()) \ + .value("NVTE_VANILLA_SOFTMAX", NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) \ + .value("NVTE_OFF_BY_ONE_SOFTMAX", NVTE_Softmax_Type::NVTE_OFF_BY_ONE_SOFTMAX) \ + .value("NVTE_LEARNABLE_SOFTMAX", NVTE_Softmax_Type::NVTE_LEARNABLE_SOFTMAX); \ pybind11::enum_(m, "NVTE_QKV_Format", pybind11::module_local()) \ .value("NVTE_BSHD", NVTE_QKV_Format::NVTE_BSHD) \ .value("NVTE_SBHD", NVTE_QKV_Format::NVTE_SBHD) \ diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 0d667a0ec..dd6869e02 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -11,7 +11,7 @@ #include "../common.h" #include "../utils.cuh" - +#include "math.h" namespace transformer_engine { /* \brief Helper class that enables storing multiple values of type DType @@ -338,7 +338,7 @@ template void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output, const 
fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N, - const Param params, cudaStream_t stream) { + const Param ¶ms, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, output); @@ -372,7 +372,7 @@ template void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input, OutputType *output, const fp32 *scale, fp32 *amax, - fp32 *scale_inv, const size_t N, const Param params, + fp32 *scale_inv, const size_t N, const Param ¶ms, cudaStream_t stream) { if (N != 0) { auto align = CheckAlignment(N, nvec, input, grad, output); @@ -431,7 +431,13 @@ __launch_bounds__(unary_kernel_threads) __global__ #pragma unroll for (int i = 0; i < nvec; ++i) { const ComputeType val = static_cast(loader0.separate()[i]); - const ComputeType val2 = static_cast(loader1.separate()[i]); + ComputeType val2 = static_cast(loader1.separate()[i]); + + if constexpr (std::is_same::value) { + // Clamp the gated value and add 1 at the end + ComputeType limit = p.limit; + val2 = std::min(std::max(-limit, val2), limit) + 1; + } ComputeType temp = static_cast(Activation(val, p) * val2); if (requires_amax) { __builtin_assume(max >= 0); @@ -532,10 +538,18 @@ __launch_bounds__(unary_kernel_threads) __global__ for (int i = 0; i < nvec; ++i) { const ComputeType grad_val = static_cast(grad_loader.separate()[i]); const ComputeType gelu_in = static_cast(input_loader0.separate()[i]); - const ComputeType gate_in = static_cast(input_loader1.separate()[i]); + ComputeType gate_in = static_cast(input_loader1.separate()[i]); + bool dgate_in = true; + + if constexpr (std::is_same::value) { + // In case of GPT OSS, clamp the activation and gate values + const ComputeType limit = p.limit; + dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp + gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f; + } ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; - ComputeType after_dgate = grad_val * Activation(gelu_in, p); + 
ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f; if (requires_amax) { __builtin_assume(max >= 0); diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 799becaee..c56242d34 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -31,6 +31,7 @@ typedef uint16_t hip_bfloat16x2 __attribute__((ext_vector_type(2))); #endif #if !defined(__CUDACC_RTC__) +#include #include #else // Importing C++ standard headers is a pain with NVRTC @@ -52,6 +53,39 @@ constexpr uint32_t THREADS_PER_WARP = 32; //////////////////////////////////////////////////////////////////////////////////////////////////// +// Device-side error +#define NVTE_DEVICE_ERROR(message) \ + do { \ + printf("%s:%d in function %s (thread (%d,%d,%d), block (%d,%d,%d)): %s\n", __FILE__, __LINE__, \ + __func__, threadIdx.x, threadIdx.y, threadIdx.z, blockIdx.x, blockIdx.y, blockIdx.z, \ + (message)); \ + assert(0); \ + } while (false) + +// Device-side error on thread 0 +#define NVTE_DEVICE_THREAD0_ERROR(message) \ + do { \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && \ + threadIdx.y == 0 && threadIdx.z == 0) { \ + NVTE_DEVICE_ERROR(message); \ + } \ + } while (false) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 operator+(const float2 &a, const float2 &b) { // NOLINT(*) + return {a.x + b.x, a.y + b.y}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void operator+=(float2 &a, const float2 &b) { // NOLINT(*) + a.x += b.x; + a.y += b.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template struct Sum { inline __device__ Sum() {} diff --git a/transformer_engine/debug/features/_test_dummy_feature.py 
b/transformer_engine/debug/features/_test_dummy_feature.py index c8a31a343..4dee97b70 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -7,19 +7,55 @@ from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.api import TEConfigAPIMapper +# Module-level counters for tracking invocations +# NOTE: These must be accessed via the full module path +# (transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count) +# to ensure the same module instance is used when the feature is loaded by the debug framework +# and when imported by tests. Using just the variable name would create separate instances +# in different import contexts. +_inspect_tensor_enabled_call_count = 0 +_inspect_tensor_call_count = 0 + @Registry.register_feature(namespace="transformer_engine") class TestDummyFeature(TEConfigAPIMapper): """ - This is feature used only in tests. It invokes look_at_tensor_before_process - and does nothing. + This is feature used only in tests. It invokes inspect_tensor and does nothing. If no features are used, then TE layer automatically switches to the non-debug mode. This feature is invoked for each GEMM to prevent this behavior. + + Config options: + - inspect_only_once: if True, return (False, None) from inspect_tensor_enabled to test caching behavior + + Note: This feature always tracks invocations for testing purposes. """ @api_method - def inspect_tensor_enabled(self, *_args, **_kwargs): - """API call used to determine whether to run look_at_tensor_before_process - in the forward pass.""" + def inspect_tensor_enabled(self, config, *_args, **_kwargs): + """API call used to determine whether to run inspect_tensor in the forward pass. + + Always tracks calls for testing purposes. 
+ + Returns: + - If inspect_only_once=True in config: returns (False, None) - check once, never call inspect_tensor + - Otherwise: returns True - feature is always enabled + """ + # Access counter via full module path to ensure we're modifying the same module-level + # variable regardless of import context (debug framework vs test import) + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self + + dummy_feature._inspect_tensor_enabled_call_count += 1 + + inspect_only_once = config.get("inspect_only_once", False) + if inspect_only_once: + return False, None return True + + @api_method + def inspect_tensor(self, _config, *_args, **_kwargs): + """This method does nothing but always tracks invocations for testing.""" + # Access counter via full module path to ensure shared state across import contexts + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self + + dummy_feature._inspect_tensor_call_count += 1 diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index 4b01f9712..58c7379b5 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.tensor import Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.fp8 import _default_sf_compute +from transformer_engine.pytorch.quantization import _default_sf_compute def fake_quantize(tensor: torch.Tensor, fp8_format: tex.DType, out=None): diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index 31620211d..d09fb1057 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ 
b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -290,10 +290,16 @@ def inspect_tensor( for stat in config["stats"]: self.check_if_stat_is_supported(stat, recipe_name) + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + options = ( - config.get("start_step", None), - config.get("end_step", None), - config.get("start_end_list", None), + start_step, + end_step, + start_end_list, "fp8", ) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 7ba2f9f77..ff37e659a 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -4,7 +4,7 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, Optional +from typing import Dict, Optional, List import torch @@ -15,10 +15,14 @@ from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase -from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import Float8TensorStorage +from transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params +from transformer_engine.debug.features.utils.stats_computation import ( + add_max_blockwise_dynamic_range_stats, + BlockwiseDynamicRangeStat, +) 
@Registry.register_feature(namespace="transformer_engine") @@ -44,7 +48,14 @@ class LogTensorStats(BaseLogTensorStats): - l1_norm - l2_norm - cur_amax – maximal absolute value of a tensor, - - dynamic_range – equal to `torch.log2(amax) - torch.log2(amin)` + - dynamic_range – equal to `torch.log2(amax) - torch.log2(nonzero_amin)` + - max_blockwise_dynamic_range – Computes the maximum dynamic range `log2(amax) - log2(nonzero_amin)` across all blocks of size block_size within the tensor. + If tensor and its transpose is needed in training, this stat is computed for both orientations and the maximum is returned. + For `dim=1` there are block_size consecutive elements in the block, for `dim=2` the block is block_size x block_size elements tile. + + - block_size: int, default = 32 + - dims: int, default = 1, allowed values are 1 and 2 + tensors/tensors_struct: List[str] list of tensors to log @@ -88,6 +99,60 @@ class LogTensorStats(BaseLogTensorStats): stats: [dynamic_range] """ + def _is_supported_stat(self, stat: str | Dict): + """Returns True if the stat is supported by this feature, False otherwise.""" + if isinstance(stat, dict): + stat_name = list(stat.keys())[0] + if stat_name == "max_blockwise_dynamic_range": + stat_dict = stat[stat_name] + if not isinstance(stat_dict, dict): + return False + # Ensure only supported keys are present + allowed_keys = {"block_size", "dims"} + if any(k not in allowed_keys for k in stat_dict.keys()): + return False + block_size = stat_dict.get("block_size", 32) + dims = stat_dict.get("dims", 1) + # Type and value validation + if not isinstance(block_size, int) or not isinstance(dims, int): + return False + if block_size > 0 and dims in [1, 2]: + return True + return False + return stat in BaseLogTensorStats._get_supported_stats_list(None) | { + "cur_amax", + "dynamic_range", + } + + def _parse_max_blockwise_dynamic_range_stats( + self, stats: List[str | Dict], tensor_name: str + ) -> List[str | BlockwiseDynamicRangeStat]: + """ + 
Adds all max_blockwise_dynamic_range stats to the stat computation logic. + Changes the types of the stats from Dict to BlockwiseDynamicRangeStat named tuple, + for other stats nothing is changed. + + For example, if the stats is [{"max_blockwise_dynamic_range": {"block_size": 32, "dims": 1}}], + it will be changed to [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=True)] + or [BlockwiseDynamicRangeStat(block_size=32, dims=1, max_over_orientations=False)] depending on tensor_name. + + """ + max_over_orientations = tensor_name in ["activation", "weight"] + parsed_stats = [] + for stat in stats: + if isinstance(stat, dict): + block_size = stat["max_blockwise_dynamic_range"].get("block_size", 32) + dims = stat["max_blockwise_dynamic_range"].get("dims", 1) + + # Register stat and return the named tuple + parsed_stat = add_max_blockwise_dynamic_range_stats( + block_size, dims, max_over_orientations + ) + parsed_stats.append(parsed_stat) + else: + parsed_stats.append(stat) + return parsed_stats + def _get_supported_stats_list(self): """Returns stats this feature can log.""" return BaseLogTensorStats._get_supported_stats_list(None) | {"cur_amax", "dynamic_range"} @@ -123,17 +188,23 @@ def inspect_tensor( """API call used to collect the data about the tensor before process_tensor()/quantization.""" assert ( - type(tensor) not in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase] + type(tensor) not in [Float8Tensor, Float8TensorStorage, MXFP8Tensor, MXFP8TensorStorage] and tensor.dtype != torch.uint8 ), ( f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be in high precision when using" " log_tensor_stats. Use log_fp8_tensor_stats for FP8 tensors." 
) + start_step = config.get("start_step", None) + end_step = config.get("end_step", None) + start_end_list = config.get("start_end_list", None) + if start_end_list is not None: + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + options = ( - config.get("start_step", None), - config.get("end_step", None), - config.get("start_end_list", None), + start_step, + end_step, + start_end_list, ) skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( @@ -141,14 +212,16 @@ def inspect_tensor( ) for stat in config["stats"]: - assert ( - stat in self._get_supported_stats_list() + assert self._is_supported_stat( + stat ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." + stats = self._parse_max_blockwise_dynamic_range_stats(config["stats"], tensor_name) + STATS_BUFFERS.try_add_buffer( layer_name=layer_name, tensor_name=tensor_name, - stats=config["stats"], + stats=stats, options=options, reduction_group=reduction_group, reduce_within_microbatch=reduce_within_microbatch, diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index f07602d23..b5b462f5a 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -130,8 +130,12 @@ def log(self): for stat_name in self.stats_to_log: combiner = STATS[stat_name][1] stat_value = combiner(gathered_helper_stats) + + # Convert stat key to string for logging (uses __str__ for named tuples) + stat_name_str = str(stat_name) + MetricLogger.log_scalar( - f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration + f"{self.layer_name}_{self.tensor_name}_{stat_name_str}", stat_value, self.iteration ) output[(self.layer_name, self.tensor_name, stat_name, self.iteration)] = ( stat_value # for debugging purposes @@ -172,11 +176,19 @@ def _if_run_reduction(self) -> bool: if self.at_least_one_layer_fed: 
return True iteration = TEDebugState.get_iteration() - for _, next_iter in self.layers_to_next_iter.items(): + layers_to_remove = [] + for layer_name, next_iter in self.layers_to_next_iter.items(): + # When next_iter is None the feature will no longer run. + if next_iter is None: + layers_to_remove.append(layer_name) + continue # Note that layer can be not run for many iterations, # in this case we will synchronize until every step until we get any information from it. if iteration >= next_iter: return True + + for layer_name in layers_to_remove: + self.layers_to_next_iter.pop(layer_name, None) return False def reset(self): diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 2fa6985ac..8c480441c 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -7,12 +7,25 @@ """ import math +from collections import namedtuple + import torch import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.common.recipe import Format +class BlockwiseDynamicRangeStat( + namedtuple("BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"]) +): + """Named tuple representing a blockwise dynamic range statistic configuration.""" + + def __str__(self) -> str: + """Convert to string representation for stat name. 
Used for logging.""" + suffix = "_max_over_orientations" if self.max_over_orientations else "" + return f"max_blockwise_dynamic_range_block_size_{self.block_size}_dims_{self.dims}{suffix}" + + @torch.compile def _compute_dynamic_range_top(tensor): """Computes the log2 of the amax of the tensor""" @@ -26,6 +39,7 @@ def _compute_dynamic_range_top(tensor): return torch.log2(amax) +@torch.compile def _compute_dynamic_range_bottom(tensor): """Computes the log2 of the amin of the tensor""" tensor_abs = tensor.abs() @@ -37,6 +51,76 @@ def _compute_dynamic_range_bottom(tensor): return torch.log2(amin) +def compute_max_blockwise_dynamic_range(tensor, stat_config): + """ + Computes maximum blockwise dynamic range (log2 max/min_nonzero) within blocks. + + Flattens tensor to 2D and computes maximum dynamic range within blocks. If max_over_orientations + is True, computes for both rowwise and columnwise orientations and returns the maximum, + capturing the worst-case scenario regardless of how the tensor is used in GEMM operations. + If False, computes only for rowwise orientation. + + Returns 0 if all blocks are zeros, otherwise computes dynamic range over non-zero blocks. 
+ + Args: + tensor: Input tensor (will be flattened to 2D) + stat_config: BlockwiseDynamicRangeStat named tuple with: + - block_size: Size of blocks (int) + - dims: 1 for 1D blocks (consecutive elements), 2 for 2D blocks (tiles) + - max_over_orientations: If True, compute max over rowwise and columnwise orientations + """ + # Extract parameters from stat_config + block_size = stat_config.block_size + dims = stat_config.dims + max_over_orientations = stat_config.max_over_orientations + + def _compute_for_one_orientation(tensor): + total_numel = tensor.numel() + assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" + + # torch.compile friendly code - standard ** power does not work with jit + total_block_size = block_size * block_size if dims == 2 else block_size + assert ( + total_numel % total_block_size == 0 + ), f"Tensor numel ({total_numel}) is not divisible by block_size ({block_size})." + + tensor = tensor.abs().float() + if dims == 1: + tensor = tensor.reshape(-1, block_size) + per_block_amax = tensor.amax(dim=1) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=1) + else: + # We want to have tensor of shape [nr_blocks, block_size, block_size], + # where each block is a block_size x block_size tile of the original tensor. 
+ dim_y = tensor.shape[-1] // block_size + tensor = ( + tensor.reshape(-1, block_size, dim_y, block_size) + .permute(0, 2, 1, 3) + .reshape(-1, block_size, block_size) + ) + per_block_amax = tensor.amax(dim=(1, 2)) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2)) + + # Identify blocks that contain any non-zero element + nonzero_blocks = per_block_amax != 0 + dynamic_range_per_block = torch.where( + nonzero_blocks, + torch.log2(per_block_amax) - torch.log2(per_block_amin), + torch.zeros_like(per_block_amax, dtype=torch.float32), + ) + return dynamic_range_per_block.max() + + # Flatten to 2D + tensor_2d = tensor.reshape(-1, tensor.shape[-1]) + if max_over_orientations: + return max( + _compute_for_one_orientation(tensor_2d), # Rowwise orientation + _compute_for_one_orientation(tensor_2d.transpose(-2, -1)), # Columnwise orientation + ) + return _compute_for_one_orientation(tensor_2d) + + +@torch.compile def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) @@ -45,6 +129,7 @@ def compute_variance(variances, numels, sums): return var +@torch.compile def compute_std(variances, numels, sums): """Computates standard deviation.""" return torch.sqrt(compute_variance(variances, numels, sums)) @@ -316,6 +401,37 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} +def add_max_blockwise_dynamic_range_stats( + block_size: int, dims: int, max_over_orientations: bool = False +): + """Register max_blockwise_X_dynamic_range stats for the recipe. 
+ + Args: + block_size: Size of blocks for computing blockwise dynamic range + dims: 1 for 1D blocks, 2 for 2D blocks + max_over_orientations: Whether to compute max over rowwise and columnwise orientations + + Returns: + BlockwiseDynamicRangeStat named tuple representing this stat (used as the stat key) + """ + # Use named tuple directly as the stat key - this is cleaner than string keys + stat_key = BlockwiseDynamicRangeStat(block_size, dims, max_over_orientations) + + if stat_key in stats_to_num: + return stat_key # already registered + + assert dims in [1, 2], f"dims must be 1 or 2, got {dims}" + stats_to_num[stat_key] = len(stats_to_num) + DEPENDENCIES[stat_key] = {stat_key} + + STATS[stat_key] = ( + lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range(x, _stat_key), + lambda buffers, _stat_key=stat_key: max(_get(buffers, _stat_key)), + ) + + return stat_key + + for _columnwise in [True, False]: for _recipe_name in [ "", # default recipe diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index d564ca8e9..7f45a24e2 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -15,10 +15,10 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( QuantizedTensor, Quantizer, - QuantizedTensorBase, + QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) @@ -557,7 +557,7 @@ def set_usage(self, rowwise: bool = None, columnwise: bool = None): self._update_parent_quantizer_usage() -class DebugQuantizedTensor(QuantizedTensorBase): +class DebugQuantizedTensor(QuantizedTensorStorage): """ Class containing quantized tensors after debug. Depending on configuration it can contain one or two different objects. 
These objects can be accessed by the method diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 0b5e43402..6259a7ad8 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -34,7 +34,7 @@ from . import flax from . import quantize -from .quantize import fp8_autocast, update_collections, get_delayed_scaling +from .quantize import autocast, fp8_autocast, update_collections from .quantize import NVTE_FP8_COLLECTION_NAME from .sharding import MeshResource @@ -45,9 +45,9 @@ __all__ = [ "NVTE_FP8_COLLECTION_NAME", + "autocast", "fp8_autocast", "update_collections", - "get_delayed_scaling", "MeshResource", "flax", "quantize", diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index 12b35ec43..daa3679c4 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -11,7 +11,6 @@ import jax import jax.numpy as jnp - from . import cpp_extensions as tex from .quantize.tensor import NoScaleTensor @@ -22,6 +21,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[tex.activation.ActivationParams] = None, ) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. @@ -32,17 +32,19 @@ def activation( x: Input tensor to apply activations to activation_type: Sequence of activation functions quantizer: Optional quantizer for quantizing the output + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. 
Returns: Activated output tensor """ assert x.shape[-1] % len(activation_type) == 0 - output = _activation(x, activation_type, quantizer) + output = _activation(x, activation_type, quantizer, act_params) return output -@partial(jax.custom_vjp, nondiff_argnums=(1,)) -def _activation(x, activation_type, quantizer): +@partial(jax.custom_vjp, nondiff_argnums=(1, 3)) +def _activation(x, activation_type, quantizer, act_params): """Internal implementation of activation with custom VJP. This function implements the core activation logic with support for @@ -52,36 +54,42 @@ def _activation(x, activation_type, quantizer): x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Activated tensor """ - _output, _ = _activation_fwd_rule(x, activation_type, quantizer) + _output, _ = _activation_fwd_rule(x, activation_type, quantizer, act_params) return _output -def _activation_fwd_rule(x, activation_type, quantizer): +def _activation_fwd_rule(x, activation_type, quantizer, act_params): """Forward pass rule for activation function. Args: x: Input tensor activation_type: Sequence of activation functions quantizer: Optional quantizer + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. Returns: Tuple of (output, context) for backward pass """ - fwd_output = tex.act_lu(x, activation_type, quantizer) + fwd_output = tex.act_lu(x, activation_type, quantizer, act_params) # This is a no-op for higher-precision tensors fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) -def _activation_bwd_rule(activation_type, ctx, g): +def _activation_bwd_rule(activation_type, act_params, ctx, g): """Backward pass rule for activation function. Args: activation_type: Sequence of activation functions + act_params: Optional activation parameters. Currently used + just for ClampedSwiGLU. 
ctx: Context from forward pass g: Gradient from upstream @@ -90,7 +98,7 @@ def _activation_bwd_rule(activation_type, ctx, g): """ (x, _) = ctx assert x.dtype == g.dtype - dx = tex.dact_lu(g, x, activation_type) + dx = tex.dact_lu(g, x, activation_type, act_params=act_params) # No quantization is used in this VJP backward, so the output should # always be a NoScaleTensor assert isinstance(dx, NoScaleTensor) diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 093146162..1ce44a2b9 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -209,6 +209,8 @@ def make_swa_mask( segment_pos_kv: jnp.ndarray, window_size: Optional[Tuple[int, int]] = None, dtype: jax.typing.DTypeLike = jnp.float32, + segment_ids_q: jnp.ndarray = None, + segment_ids_kv: jnp.ndarray = None, ): """ Generate a sliding window mask (1 = attend, 0 = masked). @@ -227,6 +229,10 @@ def make_swa_mask( Defaults to None. dtype (jax.typing.DTypeLike, optional): Mask data type. Defaults to jnp.float32. 
+ segment_ids_q (jnp.ndarray): + Query segment id that each token belongs to + segment_ids_kv (jnp.ndarray): + Key/value segment id that each token belongs to Returns: jnp.ndarray: @@ -240,6 +246,18 @@ def make_swa_mask( right_window = jnp.inf if right_window < 0 else right_window pos_q = jnp.expand_dims(segment_pos_q, axis=-1) pos_kv = jnp.expand_dims(segment_pos_kv, axis=-2) + # For Bottom Right Causal Mask (BRCM) + if segment_ids_q is not None and segment_ids_kv is not None: + run_length_q = run_length_fill(segment_ids_q) + run_length_kv = run_length_fill(segment_ids_kv) + run_length_q_exp = jnp.expand_dims(run_length_q, axis=-1) + run_length_kv_exp = jnp.expand_dims(run_length_kv, axis=-2) + bottom_right_inv_swa_mask = ( + run_length_q_exp - pos_q + left_window >= run_length_kv_exp - pos_kv + ) + bottom_right_inv_swa_mask = jnp.expand_dims(bottom_right_inv_swa_mask, axis=-3) + return bottom_right_inv_swa_mask.astype(dtype) + # All other cases other than BRCM inv_swa_mask = (pos_kv >= pos_q - left_window) & (pos_kv <= pos_q + right_window) inv_swa_mask = jnp.expand_dims(inv_swa_mask, axis=-3) return inv_swa_mask.astype(dtype) @@ -420,6 +438,42 @@ def _segment_ids_pos_to_seqlens_offsets_fast_causal_path( ) +def run_length_fill_flattened(segment_ids_flattened) -> jnp.ndarray: + """ + Returns an array of run-lengths of the flattened segment ids + """ + # Example for run_length_fill_flattened: + # Input segment_ids_flattened: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]] + # run_ids: [[0 0 1 1 1 2 3 4 5 5 5 5 5 6 6 6], [0 1 1 2 2 2 3 3 4 4 5 5 5 5 6 6]] + # counts: [[2 3 1 1 1 5 3 0 0 0 0 0 0 0 0 0], [1 2 3 2 2 4 2 0 0 0 0 0 0 0 0 0]] + # Returns segment_ids_run_length_1d: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]] + boundary = jnp.concatenate( + [jnp.broadcast_to(True, (1,)), segment_ids_flattened[1:] != segment_ids_flattened[:-1]] + ) + run_ids = jnp.cumsum(boundary) - 1 + # Each element could, in worst case, 
start a run + max_runs = segment_ids_flattened.shape[-1] + counts = jnp.bincount(run_ids, length=max_runs) + # Fill in the missing values + segment_ids_run_length_1d = counts[run_ids] + segment_ids_run_length_1d = jnp.where(segment_ids_flattened == 0, 0, segment_ids_run_length_1d) + return segment_ids_run_length_1d + + +def run_length_fill(segment_ids) -> jnp.ndarray: + """ + Returns an array of run-lengths of the segment ids, with shape preserved + """ + # Example for run_length_fill: + # Input segment_ids: [[1 1 2 2 2 0 3 0 4 4 4 4 4 0 0 0], [1 0 0 2 2 2 0 0 3 3 4 4 4 4 0 0]] + # Returns run length: [[2 2 3 3 3 0 1 0 5 5 5 5 5 0 0 0], [1 0 0 3 3 3 0 0 2 2 4 4 4 4 0 0]] + # Flatten all dimension except the last one prior to executing vmap run length + orig_shape = segment_ids.shape + segment_ids_flat = segment_ids.reshape(-1, orig_shape[-1]) + run_length_segment_id_shape = jax.vmap(run_length_fill_flattened, in_axes=0)(segment_ids_flat) + return run_length_segment_id_shape.reshape(orig_shape) + + def _segment_ids_pos_to_seqlens_offsets( segment_ids_q, segment_ids_kv, @@ -443,7 +497,10 @@ def _segment_ids_pos_to_seqlens_offsets( # # This fast path avoids expanding the mask to Q * KV matrix and instead allows us to # examine only O(Q+KV) elements. 
- if attn_mask_type.is_causal() and window_size is None or window_size == (-1, -1): + # TODO(KshitijLakhani): Try exercising the fast path for BRCM as well + if (attn_mask_type.is_causal() and window_size is None) or ( + window_size == (-1, -1) and not attn_mask_type.is_bottom_right() + ): return _segment_ids_pos_to_seqlens_offsets_fast_causal_path( segment_ids_q, segment_ids_kv, segment_pos_q, segment_pos_kv, max_segments_per_seq ) @@ -459,8 +516,41 @@ def _segment_ids_pos_to_seqlens_offsets( segment_ids_kv, lambda x, y: jnp.equal(x, y) * x, ) + # TE JAX Attn expects the THD segments to have q_token <= kv_tokens so that a correct cross-attn type BRCM can be applied attn_mask = segment_mask - if attn_mask_type.is_causal(): + if attn_mask_type.is_bottom_right(): + run_length_out_q = run_length_fill(segment_ids_q) + run_length_out_kv = run_length_fill(segment_ids_kv) + # Example for brcm: + # run_length_out_q: [3 3 3 0 4 4 4 4] + # segment_pos_q: [0 1 2 3 0 1 2 3] + # segment_ids_q: [1 1 1 0 2 2 2 2] + # run_length_out_kv: [4 4 4 4 0 0 10 10 10 10 10 10 10 10 10 10] + # segment_pos_kv: [0 1 2 3 4 5 0 1 2 3 4 5 6 7 8 9] + # segment_ids_kv: [1 1 1 1 0 0 2 2 2 2 2 2 2 2 2 2] + # brcm: [[[1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] + # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] + # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] + # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1] + # [1 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] + # [1 1 0 0 0 0 1 1 1 1 1 1 1 1 0 0] + # [1 1 1 0 0 0 1 1 1 1 1 1 1 1 1 0] + # [1 1 1 1 0 0 1 1 1 1 1 1 1 1 1 1]]] + # attn_mask(noswa):[[[1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0] + # [1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0] + # [1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0] + # [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 0 0 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0] + # [0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1]]] + bottom_right_causal_mask = make_attention_mask( + run_length_out_q - segment_pos_q, + run_length_out_kv - segment_pos_kv, + jnp.less_equal, + ) + attn_mask = 
jnp.logical_and(segment_mask, bottom_right_causal_mask) + elif attn_mask_type.is_causal(): causal_mask = make_attention_mask( segment_pos_q, segment_pos_kv, @@ -468,7 +558,19 @@ def _segment_ids_pos_to_seqlens_offsets( ) attn_mask = jnp.logical_and(segment_mask, causal_mask) - swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + # TODO(KshitijLakhani): Evaluate if swa_mask is needed to procure seqlen and offsets + swa_mask = ( + make_swa_mask( + segment_pos_q, + segment_pos_kv, + window_size, + dtype=jnp.bool, + segment_ids_q=segment_ids_q, + segment_ids_kv=segment_ids_kv, + ) + if attn_mask_type.is_bottom_right() + else make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool) + ) attn_mask = jnp.logical_and(attn_mask, swa_mask) attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0) @@ -1125,5 +1227,4 @@ def fused_attn( context_parallel_axis=context_parallel_axis, context_checkpoint_name=context_checkpoint_name, ) - return output diff --git a/transformer_engine/jax/cpp_extensions/__init__.py b/transformer_engine/jax/cpp_extensions/__init__.py index ef8d76cd0..c0285e157 100644 --- a/transformer_engine/jax/cpp_extensions/__init__.py +++ b/transformer_engine/jax/cpp_extensions/__init__.py @@ -3,6 +3,7 @@ # See LICENSE for license information. """Python interface for c++ extensions""" from .activation import * +from .amax import * from .attention import * from .normalization import * from .quantization import * diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index ef2643359..df148265d 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. 
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -7,18 +7,17 @@ from typing import Sequence, Union, Callable, Optional, Tuple import operator from functools import reduce, partial -from packaging import version +from dataclasses import dataclass import jax import jax.numpy as jnp -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule +from jax import dtypes, ffi +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.sharding import PartitionSpec +import numpy as np import transformer_engine_jax from transformer_engine_jax import NVTE_Activation_Type - from .base import BasePrimitive, register_primitive from .misc import ( jax_dtype_to_te_dtype, @@ -30,7 +29,7 @@ should_apply_1x_fused_dbias_war_for_arch_l_100, NamedSharding, ) -from .quantization import _jax_dbias, _quantize_dbias_impl +from .quantization import _jax_dbias, quantize, quantize_dbias, _quantize_dbias_impl, AmaxScope from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( @@ -40,10 +39,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports __all__ = ["act_lu", "dact_lu", "quantize_dact_dbias"] @@ -59,17 +54,87 @@ ("quick_gelu", "linear"): NVTE_Activation_Type.QGEGLU, ("squared_relu",): NVTE_Activation_Type.SRELU, ("squared_relu", "linear"): NVTE_Activation_Type.SREGLU, + ("clamped_silu", "clamped_linear"): NVTE_Activation_Type.CLAMPED_SWIGLU, } -def _convert_to_activation_function(fn_or_string): +@dataclass(frozen=True) +class ClampedSwigluParams: + """Parameters for the Clamped SwiGLU activation function + used in 
GPT OSS.""" + + limit: float = 7.0 + alpha: float = 1.702 + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work. + + Returns: + int: Hash value of the dataclass instance. + """ + return hash((self.limit, self.alpha)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. + """ + return {"limit": np.float32(self.limit), "alpha": np.float32(self.alpha)} + + +@dataclass(frozen=True) +class ActivationParams: + """Parameters for various activation functions. + Currently only Clamped SwiGLU activation has parameters. + """ + + clamped_swiglu: ClampedSwigluParams = ClampedSwigluParams() + + @staticmethod + def create(activation_type, **kwargs): + """Factory method to create ActivationParams based on activation_type.""" + CLAMPED_ACTIVATION_TYPES = { + ("clamped_silu", "clamped_linear"), + "clamped_silu", + "clamped_linear", + } + if activation_type in CLAMPED_ACTIVATION_TYPES: + return ActivationParams(ClampedSwigluParams(**kwargs)) + return ActivationParams() # Default params for activations without parameters + + def __hash__(self): + """Custom hash function to ensure dataclass is hashable for jax jit to work""" + return hash((self.clamped_swiglu,)) + + def to_ffi_lowering_dict(self): + """Convert the activation parameters to a dictionary format for FFI lowering. + Returns: + dict: A dictionary representation of the activation parameters consumable by + XLA FFI bindings for activation functions. 
+ """ + return {"clamped_swiglu": self.clamped_swiglu.to_ffi_lowering_dict()} + + +def _convert_to_activation_function(fn_or_string, act_params: ActivationParams): """Convert a string to an activation function.""" if fn_or_string == "linear": return lambda x: x + if fn_or_string == "clamped_linear": + # This function is used for ClampedSwiGLU + # used in GPT OSS where the gates are not only clamped + # but also shifted by +1 + limit = act_params.clamped_swiglu.limit + return lambda x: jnp.clip(x, min=-limit, max=limit) + 1 if fn_or_string == "quick_gelu": return lambda x: jax.nn.sigmoid(1.702 * x) * x if fn_or_string == "squared_relu": return lambda x: reduce(operator.mul, [jax.nn.relu(x), jax.nn.relu(x)]) + if fn_or_string == "clamped_silu": + limit = act_params.clamped_swiglu.limit + alpha = act_params.clamped_swiglu.alpha + return lambda x: jax.nn.sigmoid(alpha * jnp.minimum(x, limit)) * jnp.minimum(x, limit) if isinstance(fn_or_string, str): return getattr(jax.nn, fn_or_string) if callable(fn_or_string): @@ -85,14 +150,18 @@ class ActLuPrimitive(BasePrimitive): name = "te_act_lu_ffi" multiple_results = True impl_static_args = ( - 2, 3, 4, 5, 6, 7, 8, - ) # out_dtype, act_enum, act_len, scaling_mode, is_2x, scale_dtype, is_outer + 9, + 10, + 11, + 12, + 13, + ) # out_dtype, act_enum, act_len, scaling_mode, quantize_layout, scale_dtype, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer inner_primitive = None outer_primitive = None @@ -100,22 +169,31 @@ class ActLuPrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, *, out_dtype, act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ te_act_lu_p abstract """ - del act_enum + del act_enum, act_params, amax_scope, transpose_batch_sequence + assert ( + not output_amax_when_no_scaling or scaling_mode == ScalingMode.NO_SCALING.value + ), 
f"scaling_mode = {scaling_mode}" dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 assert x_aval.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x_aval.shape} and act_len {act_len}" @@ -125,6 +203,13 @@ def abstract( "Current tensor scaling is not yet supported for fused activation and quantization." " Please do activation in higher-precision then quantize with current tensor scaling." ) + assert not ScalingMode(scaling_mode).is_nvfp4_scaling, ( + "NVFP4 block scaling is not yet supported for fused activation and quantization." + " Please do activation in higher-precision then quantize with current tensor scaling." + ) + assert ( + not quantize_layout.is_colwise_only + ), "Fused activation with colwise-only quantization is not supported." out_shape = (*x_aval.shape[:-2], x_aval.shape[-1]) # Exclude act dim out_aval = x_aval.update(shape=out_shape, dtype=out_dtype) @@ -134,7 +219,7 @@ def abstract( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(out_shape, is_padded=not is_outer, flatten_axis=-1) - if not is_2x: + if quantize_layout.is_rowwise_only: out_shape = (1,) colwise_scale_inv_shape = (1,) colwise_out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) @@ -150,25 +235,42 @@ def lowering( ctx, x, scale, + amax, *, out_dtype, act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ te_gated_act_lu_p lowering rules """ - del out_dtype, scale_dtype, act_len, is_outer - x_aval, scale_aval = ctx.avals_in + del out_dtype, scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence + x_aval, scale_aval, amax_aval = ctx.avals_in assert 
x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval.dtype == jnp.float32 - out = ffi.ffi_lowering(ActLuPrimitive.name)( - ctx, x, scale, act_enum=act_enum, scaling_mode=scaling_mode.value, is_2x=is_2x + out = ffi.ffi_lowering( + ActLuPrimitive.name, + operand_output_aliases={2: 4}, # donate amax buffer to updated_amax + )( + ctx, + x, + scale, + amax, + act_enum=act_enum, + scaling_mode=scaling_mode.value, + quantize_layout=quantize_layout.value.value, + act_params=act_params.to_ffi_lowering_dict(), + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return out @@ -176,12 +278,17 @@ def lowering( def impl( x, scale, + amax, out_dtype, act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ @@ -194,12 +301,17 @@ def impl( ActLuPrimitive.inner_primitive.bind( x, scale, + amax, out_dtype=out_dtype, act_enum=act_enum, act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=False, ) ) @@ -210,7 +322,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -226,18 +338,21 @@ def batcher( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ to describe batch rules for vmap """ - del act_len, is_outer check_valid_batch_dims(batch_dims) assert ActLuPrimitive.outer_primitive is not 
None - x, scale = batched_args - x_bdim, scale_bdim = batch_dims + x, scale, amax = batched_args + x_bdim, scale_bdim, _ = batch_dims amax_bdim = scale_bdim out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim @@ -245,11 +360,18 @@ def batcher( ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=out_dtype, act_enum=act_enum, + act_len=act_len, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -260,8 +382,12 @@ def infer_sharding_from_operands( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, @@ -273,6 +399,10 @@ def infer_sharding_from_operands( act_enum, scale_dtype, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) @@ -281,7 +411,7 @@ def infer_sharding_from_operands( out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: @@ -298,7 +428,7 @@ def infer_sharding_from_operands( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -323,21 +453,25 @@ def partition( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, result_infos, ): - del result_infos, is_outer # Unused. + del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) out_spec = (*x_spec[:-2], x_spec[-1]) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="ActLuPrimitive.out") - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_out_spec = multidim_transpose(out_spec, transpose_axis=-1) else: @@ -354,7 +488,10 @@ def partition( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: + assert not ScalingMode( + scaling_mode + ).is_colwise_transposed, "Transpose layout scaling modes are not supported here yet" colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -374,25 +511,40 @@ def partition( amax_sharding, ) - def sharded_impl(x, scale): - local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, local_amax = ( - ActLuPrimitive.impl( - x, - scale, - out_dtype=out_dtype, - 
act_enum=act_enum, - act_len=act_len, - scaling_mode=scaling_mode, - is_2x=is_2x, - scale_dtype=scale_dtype, - is_outer=True, - ) + def sharded_impl(x, scale, amax): + ( + local_x, + local_colwise_x, + local_scale_inv, + local_colwise_scale_inv, + local_updated_amax, + ) = ActLuPrimitive.impl( + x, + scale, + amax, + out_dtype=out_dtype, + act_enum=act_enum, + act_len=act_len, + scaling_mode=scaling_mode, + quantize_layout=quantize_layout, + scale_dtype=scale_dtype, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=True, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, out_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return ( local_x, @@ -410,52 +562,59 @@ def shardy_sharding_rule( act_enum, act_len, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, value_types, result_types, ): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, act_enum, act_len, scale_dtype, is_outer, mesh, result_types - prefix = "ActLuPrimitive_" - x_rank = len(value_types[0].shape) + del ( + out_dtype, + act_enum, + act_len, + scale_dtype, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, + is_outer, + mesh, + result_types, + ) + prefix = "ActLu" + input_shape = 
value_types[0].shape + output_shape = input_shape[:-2] + input_shape[-1:] + # Here we pass len of output so that the scales are propagated correctly scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - x_rank - 1, unique_var=prefix + "x", flatten_axis=-2 + output_shape, unique_var=prefix, flatten_axis=-1, q_layout=quantize_layout ) - x_axes = scale_rules.input_spec + (prefix + f"x{x_rank - 1}",) - out = (*x_axes[:-2], x_axes[-1]) - scale_inv = scale_rules.rowwise_rule - - colwise_out = (prefix + "out_colwise",) - colwise_scale_inv = (prefix + "scale_inv_colwise",) - if is_2x: - colwise_scale_inv = scale_rules.colwise_rule - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple( - multidim_transpose(x_axes, static_axis_boundary=-1, transpose_axis=-2) - ) - else: - colwise_out = out - - # amax is always a unit tensor. - amax = (prefix + "amax",) + # Correct the input spec with act dim + input_spec = scale_rules.input_spec + input_spec = input_spec[:-1] + (prefix + "_act_dim",) + input_spec[-1:] + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) return SdyShardingRule( + (tuple(input_spec), scale, amax), ( - x_axes, - ("…1",), + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, ), - (out, colwise_out, scale_inv, colwise_scale_inv, amax), + **scale_rules.factor_sizes, ) register_primitive(ActLuPrimitive) -# TODO(Jeremy): replace is_2x with q_layout class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): """ DActLu DBias Cast Transpose Primitive @@ -463,8 +622,8 @@ class BaseDActLuDBiasQuantizePrimitive(BasePrimitive): name = "te_dact_dbias_quantize_ffi" multiple_results = True - # out_dtype, scaling_mode, is_2x, scale_dtype, is_dbias, act_enum, act_len, is_outer - impl_static_args = (3, 4, 5, 6, 7, 8, 9, 10) + # out_dtype, scaling_mode, quantize_layout, scale_dtype, is_dbias, act_enum, act_len, act_params, 
amax_scope, transpose_batch_sequence, output_amax_when_no_scaling, is_outer + impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -473,20 +632,25 @@ def abstract( dz_aval, x_aval, scale_aval, + amax_aval, *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ te_dact_dbias_quantize_p abstract """ - del act_enum + del act_enum, act_params, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling dz_dtype = dtypes.canonicalize_dtype(dz_aval.dtype) assert dz_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_dtype @@ -495,6 +659,7 @@ def abstract( f" {x_aval.shape} and act_len {act_len}" ) assert scale_aval.dtype == jnp.float32 + assert amax_aval.dtype == jnp.float32 assert scaling_mode != ScalingMode.CURRENT_TENSOR_SCALING.value, ( "Current tensor scaling is not supported for fused dact and quantization. 
Please do" @@ -515,7 +680,7 @@ def abstract( rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=-2) - if is_2x: + if quantize_layout.is_rowwise_colwise: if ScalingMode(scaling_mode).is_tensor_scaling(): colwise_out_shape = multidim_transpose(out_shape, transpose_axis=-2) else: @@ -537,7 +702,7 @@ def abstract( jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), scaling_mode, - is_2x, + quantize_layout.value, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -574,33 +739,51 @@ def lowering( dz, x, scale, + amax, *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ te_dact_dbias_quantize_p lowering rules """ - del out_dtype, scale_dtype, act_len, is_outer - dz_aval, x_aval, scale_aval = ctx.avals_in + del ( + out_dtype, + scale_dtype, + act_len, + is_outer, + amax_scope, + transpose_batch_sequence, + ) + dz_aval, x_aval, scale_aval, amax_aval = ctx.avals_in assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert x_aval.dtype == dz_aval.dtype - assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(BaseDActLuDBiasQuantizePrimitive.name)( + assert scale_aval.dtype == amax_aval.dtype == jnp.float32 + return ffi.ffi_lowering( + BaseDActLuDBiasQuantizePrimitive.name, + operand_output_aliases={3: 4}, # donate amax buffer to updated_amax + )( ctx, dz, x, scale, + amax, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, is_dbias=is_dbias, act_enum=int(act_enum), + act_params=act_params.to_ffi_lowering_dict(), + output_amax_when_no_scaling=output_amax_when_no_scaling, ) @staticmethod @@ -608,13 +791,18 @@ def impl( dz, x, scale, + amax, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, 
is_dbias, act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ @@ -627,13 +815,18 @@ def impl( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=False, ) ) @@ -644,7 +837,7 @@ def impl( scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -657,21 +850,24 @@ def batcher( *, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ to describe batch rules for vmap """ - del is_outer check_valid_batch_dims(batch_dims) assert BaseDActLuDBiasQuantizePrimitive.outer_primitive is not None - dz, x, scale = batched_args - _, x_bdim, scale_bdim = batch_dims + dz, x, scale, amax = batched_args + _, x_bdim, scale_bdim, _ = batch_dims out_bdims = ( x_bdim, # rowwise output @@ -686,13 +882,19 @@ def batcher( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -701,18 +903,23 @@ def batcher( def infer_sharding_from_operands( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, 
act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, result_infos, ): - del out_dtype, result_infos, act_enum - del scale_dtype, act_len, is_outer + del out_dtype, result_infos, act_enum, act_params, output_amax_when_no_scaling + del scale_dtype, act_len, is_outer, amax_scope, transpose_batch_sequence + x_spec = get_padded_spec(arg_infos[1]) scale_spec = get_padded_spec(arg_infos[2]) @@ -723,7 +930,7 @@ def infer_sharding_from_operands( out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: @@ -749,7 +956,7 @@ def infer_sharding_from_operands( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -776,11 +983,15 @@ def infer_sharding_from_operands( def partition( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, @@ -794,7 +1005,7 @@ def partition( mesh, PartitionSpec(*x_spec), desc="BaseDActLuDBiasQuantizePrimitive.out" ) - if is_2x: + if quantize_layout.is_rowwise_colwise: if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: colwise_x_spec = multidim_transpose(x_spec, transpose_axis=-2) else: @@ -820,7 +1031,7 @@ def partition( elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec - if is_2x: + if quantize_layout.is_rowwise_colwise: colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( @@ -848,19 +1059,24 @@ def partition( dbias_sharding, ) - def sharded_impl(dz, x, scale): - (out, 
colwise_out, scale_inv, colwise_scale_inv, local_amax, local_dbias) = ( + def sharded_impl(dz, x, scale, amax): + (out, colwise_out, scale_inv, colwise_scale_inv, local_updated_amax, local_dbias) = ( BaseDActLuDBiasQuantizePrimitive.impl( dz, x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, is_dbias=is_dbias, act_enum=act_enum, act_len=act_len, + act_params=act_params, + output_amax_when_no_scaling=output_amax_when_no_scaling, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, is_outer=True, ) ) @@ -870,9 +1086,15 @@ def sharded_impl(dz, x, scale): global_dbias = local_dbias if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, x_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return out, colwise_out, scale_inv, colwise_scale_inv, global_updated_amax, global_dbias @@ -882,39 +1104,61 @@ def sharded_impl(dz, x, scale): def shardy_sharding_rule( out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, is_dbias, act_enum, act_len, + act_params, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, value_types, result_types, ): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, scale_dtype, act_enum, act_len, is_outer, mesh, result_types - prefix = "BaseDActLuDBiasQuantizePrimitive_" + + del ( + out_dtype, + scale_dtype, + act_enum, + act_len, + act_params, + 
is_outer, + output_amax_when_no_scaling, + mesh, + result_types, + amax_scope, + transpose_batch_sequence, + ) + + prefix = "DActLuDBias_" + # get sharding rules base on the input shape scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[1].shape), unique_var=prefix + "x", flatten_axis=-2 + value_types[1].shape, + unique_var=prefix, + flatten_axis=-2, + q_layout=quantize_layout, ) - x_axes = scale_rules.input_spec - dz_axes = (*x_axes[:-2], x_axes[-1]) - out = x_axes - colwise_out = (prefix + "out_colwise",) - if is_2x: - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=-2)) - else: - colwise_out = out - dbias = x_axes[-2:] if is_dbias else (prefix + "dbias",) - amax = (prefix + "amax",) + input_spec = scale_rules.input_spec + dz_spec = (*input_spec[:-2], input_spec[-1]) + dbias = input_spec[-2:] if is_dbias else (prefix + "_dbias",) + amax = (prefix + "_amax",) + scale = (prefix + "_scale",) return SdyShardingRule( - (dz_axes, x_axes, ("…2",)), - (out, colwise_out, scale_rules.rowwise_rule, scale_rules.colwise_rule, amax, dbias), + (tuple(dz_spec), tuple(input_spec), scale, amax), + ( + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, + dbias, + ), + **scale_rules.factor_sizes, ) @@ -929,20 +1173,22 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: +def _jax_act_lu( + inputs, activation_type, quantizer=None, act_params: Optional[ActivationParams] = None +) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert inputs.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {inputs.shape} and act_len {act_len}" ) - x = jnp.split(inputs, act_len, axis=-2) acts = [] for idx, act_fn in enumerate(activation_type): - x_i = _convert_to_activation_function(act_fn)(x[idx]) + x_i = _convert_to_activation_function(act_fn, act_params)(x[idx]) acts.append(x_i) x = reduce(operator.mul, acts) x = jnp.squeeze(x, axis=-2) @@ -957,10 +1203,12 @@ def _jax_quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, ): """ JAX implementation of dact_lu and dbias with optional quantization """ + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" @@ -968,7 +1216,8 @@ def _jax_quantize_dact_dbias( ) _, vjp_func = jax.vjp( - partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) + partial(_jax_act_lu, activation_type=activation_type, act_params=act_params), + x.astype(jnp.float32), ) # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. 
dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) @@ -991,6 +1240,10 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -999,6 +1252,7 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: If quantizer is None: @@ -1006,59 +1260,100 @@ def act_lu( If quantizer is provided: A ScaledTensor containing the quantized activated input. """ + # TODO(Phuong): remove the output_amax_when_no_scaling exposure by introducing _act_lu_impl() + # Do the same with dact_dbias_quantize() and layernorm_fwd() act_type_id = ActivationEnum[activation_type].value act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - + act_params = act_params if act_params is not None else ActivationParams() if not ActLuPrimitive.enabled(): - return _jax_act_lu(x, activation_type, quantizer) + act_out = _jax_act_lu(x, activation_type, act_params=act_params) + assert ( + act_out.data.dtype == x.dtype + ), f"JAX activation output dtype {act_out.data.dtype} must match input dtype {x.dtype}" + if quantizer is None: + return act_out - # TE/common does not support colwise-only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: - return _jax_act_lu(x, activation_type, quantizer) + return 
quantize( + act_out, + quantizer=quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + # TE/common does not support colwise-only quantization yet + if quantizer is not None and quantizer.q_layout.is_colwise_only: + return _jax_act_lu(x, activation_type, quantizer, act_params) # TE/common does not support 2x quantization for DelayedScaling yet war_output = try_apply_delayed_scaling_2x_war( - f=act_lu, x=x, activation_type=activation_type, quantizer=quantizer + f=act_lu, + x=x, + activation_type=activation_type, + quantizer=quantizer, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) if war_output is not None: return war_output scale = jnp.empty((1,), jnp.float32) output_shape = (*x.shape[:-2], x.shape[-1]) + amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,) if quantizer is None: - out, _, _, _, _ = ActLuPrimitive.outer_primitive.bind( + out, _, _, _, updated_amax = ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=x.dtype, act_enum=act_type_id, act_len=act_len, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) out = out.reshape(output_shape) + # TODO(Phuong): ScaledTensorFactory to create NoScaledTensor out = NoScaleTensor( data=out, - amax=None, + amax=updated_amax if output_amax_when_no_scaling else None, ) return out - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
out = act_lu( x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, + ) + assert ( + out.data.dtype == x.dtype + ), f"Activation output dtype {out.data.dtype} must match input dtype {x.dtype}" + out = quantize( + out, + quantizer=quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out - if isinstance(quantizer, DelayedScaleQuantizer): scale = quantizer.scale @@ -1071,12 +1366,17 @@ def act_lu( ) = ActLuPrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=quantizer.q_dtype, act_enum=act_type_id, act_len=act_len, scaling_mode=quantizer.scaling_mode.value, - is_2x=quantizer.is_2x2x(), + quantize_layout=quantizer.q_layout, scale_dtype=quantizer.get_scale_dtype(), + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) @@ -1100,6 +1400,10 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1116,35 +1420,56 @@ def quantize_dact_dbias( - The gradient of the activation with respect to the input. - The gradient of the activation with respect to the bias. 
""" - + act_params = act_params if act_params is not None else ActivationParams() act_len = len(activation_type) assert x.shape[-2] == act_len, ( "activation input should be replicated by act_len in the -2 axis, got input shape" f" {x.shape} and act_len {act_len}" ) - scale = jnp.empty((), jnp.float32) + scale = jnp.empty((1,), jnp.float32) + amax = jnp.zeros((1,), jnp.float32) # need to init with zero and shape=(1,) act_type_id = ActivationEnum[activation_type] PrimitiveClass = DActLuDBiasQuantizePrimitive if is_dbias else DActLuQuantizePrimitive if not PrimitiveClass.enabled() or ( - quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE + quantizer is not None and quantizer.q_layout.is_colwise_only ): - return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, quantizer) + if quantizer is None: + return _jax_quantize_dact_dbias(dz, x, activation_type, is_dbias, act_params=act_params) + dact_out, _ = _jax_quantize_dact_dbias( + dz, x, activation_type, is_dbias=False, act_params=act_params + ) + assert ( + dact_out.data.dtype == x.dtype + ), f"JAX dact output dtype {dact_out.data.dtype} must match input dtype {x.dtype}" + return quantize_dbias( + dact_out, + quantizer, + is_dbias=is_dbias, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) if quantizer is None: - output, _, _, _, _, _ = PrimitiveClass.outer_primitive.bind( + output, _, _, _, updated_amax, _ = PrimitiveClass.outer_primitive.bind( dz, x, scale, + amax, # outputs float32 for dbias accumulation out_dtype=(jnp.float32 if is_dbias else x.dtype), # default value for no scaling, TE/common ignore this value when scale is unset scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, # unused + quantize_layout=QuantizeLayout.ROWWISE, # unused scale_dtype=jnp.float32, # unused is_dbias=False, act_enum=act_type_id, act_len=act_len, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + 
output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) output = output.astype(x.dtype) @@ -1154,17 +1479,30 @@ def quantize_dact_dbias( output = NoScaleTensor( data=output, - amax=None, + amax=updated_amax if output_amax_when_no_scaling else None, ) return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet if should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias=is_dbias, quantizer=quantizer): out = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type, + quantizer=None, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return _quantize_dbias_impl( - out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out, + quantizer, + is_dbias=True, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) is_gated = act_len == 2 @@ -1178,20 +1516,37 @@ def quantize_dact_dbias( is_dbias=is_dbias, quantizer=quantizer, flatten_axis=-2, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) if war_output is not None: return war_output - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
out = dact_lu( dz=dz, x=x, activation_type=activation_type, quantizer=None, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) out, dbias = _quantize_dbias_impl( - out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out, + is_dbias=is_dbias, + quantizer=quantizer, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, dbias @@ -1201,10 +1556,21 @@ def quantize_dact_dbias( # TE/common dact_dbias_quantize does not support gated act yet if is_dbias and is_gated: dgated = dact_lu( - dz.astype(jnp.float32), x.astype(jnp.float32), activation_type=activation_type + dz.astype(jnp.float32), + x.astype(jnp.float32), + activation_type=activation_type, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) out, dbias = _quantize_dbias_impl( - dgated, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + dgated, + quantizer, + is_dbias=True, + dq_dtype=x.dtype, + flatten_axis=-2, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, dbias @@ -1219,18 +1585,23 @@ def quantize_dact_dbias( dz, x, scale, + amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=quantizer.is_2x2x(), + quantize_layout=quantizer.q_layout, scale_dtype=quantizer.get_scale_dtype(), is_dbias=is_dbias, act_enum=act_type_id, act_len=act_len, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) # For DelayedScaling transpose, the scale buffer is shared for both rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise: colwise_scale_inv = 
rowwise_scale_inv quantizer.update(updated_amax) @@ -1255,6 +1626,10 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, + act_params: Optional[ActivationParams] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1268,11 +1643,16 @@ def dact_lu( Returns: The gradient of the activation with respect to the input. """ + act_params = act_params if act_params is not None else ActivationParams() output, _ = quantize_dact_dbias( dz=dz, x=x, activation_type=activation_type, is_dbias=False, quantizer=quantizer, + act_params=act_params, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) return output diff --git a/transformer_engine/jax/cpp_extensions/amax.py b/transformer_engine/jax/cpp_extensions/amax.py new file mode 100644 index 000000000..2f3bc402e --- /dev/null +++ b/transformer_engine/jax/cpp_extensions/amax.py @@ -0,0 +1,420 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+"""JAX/TE custom ops for amax calculation""" +from enum import Enum + + +import jax +import jax.numpy as jnp +from jax import dtypes, ffi +from jax.experimental.custom_partitioning import SdyShardingRule +from jax.sharding import PartitionSpec + +from .base import BasePrimitive, register_primitive +from .misc import ( + get_padded_spec, + NamedSharding, +) +from ..sharding import ( + global_mesh_resource, + lax_paral_op, +) +from ..quantize import ( + get_wgrad_sign_vector, + get_sign_from_vector, +) + + +__all__ = ["AmaxScope", "calculate_amax", "calculate_post_rht_amax"] + + +class AmaxScope(Enum): + """ + Amax Scope Enum + """ + + LOCAL = 1 + TPSP = 2 + FSDP = 3 + + def all_reduce_amax_along_TPSP_and_FSDP(self, amax, data_spec, transpose_batch_sequence, mesh): + """Reduce the amax based on its scope""" + gmesh = global_mesh_resource() + sequence_dim = 0 if transpose_batch_sequence else 1 + # Run AR across TPSP only when tensor-sequence is detected in the input spec + if self is AmaxScope.TPSP and data_spec[sequence_dim] == gmesh.tpsp_resource: + return lax_paral_op(amax, jax.lax.pmax, gmesh.tpsp_resource, mesh) + # Run AR across FSDP + if self is AmaxScope.FSDP: + return lax_paral_op(amax, jax.lax.pmax, gmesh.fsdp_resource, mesh) + return amax + + +class AmaxCalculationPrimitive(BasePrimitive): + """ + Amax Calculation Primitive with custom_partitioning + """ + + name = "jax_local_amax" + multiple_results = False + impl_static_args = ( + 1, + 2, + ) # amax_scope, transpose_batch_sequence + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + x_aval, + *, + amax_scope, + transpose_batch_sequence, + ): + """ + amax calcuation abstract + """ + del amax_scope, transpose_batch_sequence + + dtype = dtypes.canonicalize_dtype(x_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + + out_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + return out_aval + + @staticmethod + def impl( + x, + amax_scope, + 
transpose_batch_sequence,
+    ):
+        """
+        amax calculation implementation
+        """
+        del amax_scope, transpose_batch_sequence
+        amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32).reshape((1,))
+        return amax
+
+    @staticmethod
+    def infer_sharding_from_operands(
+        amax_scope,
+        transpose_batch_sequence,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        """
+        amax calculation infer_sharding_from_operands
+        """
+        del (amax_scope, transpose_batch_sequence, arg_infos, result_infos)  # Unused.
+        amax_sharding = NamedSharding(
+            mesh,
+            PartitionSpec(None),
+            desc="AmaxCalculationPrimitive.out_sharding",
+        )
+        return amax_sharding
+
+    @staticmethod
+    def partition(
+        amax_scope,
+        transpose_batch_sequence,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        """
+        amax calculation partition
+        """
+        del result_infos
+        x_spec = get_padded_spec(arg_infos[0])
+        amax_sharding = NamedSharding(
+            mesh,
+            PartitionSpec(None),
+            desc="AmaxCalculation.amax_sharding",
+        )
+
+        def sharded_impl(x):
+            amax = AmaxCalculationPrimitive.impl(
+                x,
+                amax_scope=amax_scope,
+                transpose_batch_sequence=transpose_batch_sequence,
+            )
+            amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
+                amax, x_spec, transpose_batch_sequence, mesh
+            )
+
+            return amax
+
+        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        return mesh, sharded_impl, amax_sharding, arg_shardings
+
+    @staticmethod
+    def shardy_sharding_rule(amax_scope, transpose_batch_sequence, mesh, value_types, result_types):
+        """
+        amax calculation shardy_sharding_rule
+        """
+        del amax_scope, transpose_batch_sequence, mesh, result_types
+        prefix = "AmaxCal"
+        input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
+        output_spec = (f"{prefix}_amax",)
+        return SdyShardingRule((input_spec,), (output_spec,))
+
+
+register_primitive(AmaxCalculationPrimitive, outer_only=True)
+
+
+class RHTAmaxCalculationPrimitive(BasePrimitive):
+    """
+    Amax Calculation Primitive with custom_partitioning for calculating regular and post-Random
Hadamard Transform (RHT) amax using TE's fused kernels.
+    """
+
+    name = "te_rht_amax_ffi"
+    multiple_results = True
+    impl_static_args = (
+        1,  # amax_scope
+        2,  # transpose_batch_sequence
+        3,  # rht_matrix_random_sign_mask_t
+        4,  # produce_regular_amax
+        5,  # flatten_axis
+    )
+    inner_primitive = None
+    outer_primitive = None
+
+    @staticmethod
+    def abstract(
+        x_aval,
+        *,
+        amax_scope,
+        transpose_batch_sequence,
+        rht_matrix_random_sign_mask_t,
+        produce_regular_amax,
+        flatten_axis,
+    ):
+        """
+        amax calculation abstract
+        """
+        del (
+            amax_scope,
+            transpose_batch_sequence,
+            rht_matrix_random_sign_mask_t,
+            produce_regular_amax,
+            flatten_axis,
+        )
+
+        dtype = dtypes.canonicalize_dtype(x_aval.dtype)
+        assert dtype in [jnp.bfloat16], f"RHT requires input to be bfloat16, but got {dtype}"
+
+        amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+        post_rht_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32)
+
+        return amax_aval, post_rht_amax_aval
+
+    @staticmethod
+    def lowering(
+        ctx,
+        x,
+        *,
+        amax_scope,
+        transpose_batch_sequence,
+        rht_matrix_random_sign_mask_t,
+        produce_regular_amax,
+        flatten_axis,
+    ):
+        """
+        te_rht_amax_ffi lowering rules
+        """
+        del amax_scope, transpose_batch_sequence
+        (x_aval,) = ctx.avals_in
+        assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
+
+        flatten_axis = flatten_axis if flatten_axis >= 0 else flatten_axis + len(x_aval.shape)
+        assert 0 < flatten_axis < len(x_aval.shape), "Flatten axis out of bounds!"
+
+        return ffi.ffi_lowering(
+            RHTAmaxCalculationPrimitive.name,
+        )(
+            ctx,
+            x,
+            rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
+            produce_regular_amax=produce_regular_amax,
+            flatten_axis=flatten_axis,
+        )
+
+    @staticmethod
+    def impl(
+        x,
+        amax_scope,
+        transpose_batch_sequence,
+        rht_matrix_random_sign_mask_t,
+        produce_regular_amax,
+        flatten_axis,
+    ):
+        """
+        amax calculation implementation
+        """
+        assert RHTAmaxCalculationPrimitive.inner_primitive is not None
+        (
+            amax,
+            post_rht_amax,
+        ) = RHTAmaxCalculationPrimitive.inner_primitive.bind(
+            x,
+            amax_scope=amax_scope,
+            transpose_batch_sequence=transpose_batch_sequence,
+            rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
+            produce_regular_amax=produce_regular_amax,
+            flatten_axis=flatten_axis,
+        )
+        return amax, post_rht_amax
+
+    @staticmethod
+    def infer_sharding_from_operands(
+        amax_scope,
+        transpose_batch_sequence,
+        rht_matrix_random_sign_mask_t,
+        produce_regular_amax,
+        flatten_axis,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        """
+        amax calculation infer_sharding_from_operands
+        """
+        del (
+            amax_scope,
+            transpose_batch_sequence,
+            rht_matrix_random_sign_mask_t,
+            produce_regular_amax,
+            flatten_axis,
+            arg_infos,
+            result_infos,
+        )  # Unused.
+        amax_sharding = NamedSharding(
+            mesh,
+            PartitionSpec(None),
+            desc="RHTAmaxCalculationPrimitive.out_sharding",
+        )
+        return amax_sharding, amax_sharding
+
+    @staticmethod
+    def partition(
+        amax_scope,
+        transpose_batch_sequence,
+        rht_matrix_random_sign_mask_t,
+        produce_regular_amax,
+        flatten_axis,
+        mesh,
+        arg_infos,
+        result_infos,
+    ):
+        """
+        amax calculation partition
+        """
+        del result_infos
+        x_spec = get_padded_spec(arg_infos[0])
+        amax_sharding = NamedSharding(
+            mesh,
+            PartitionSpec(None),
+            desc="RHTAmaxCalculationPrimitive.amax_sharding",
+        )
+        out_shardings = (amax_sharding, amax_sharding)
+
+        def sharded_impl(x):
+            amax, post_rht_amax = RHTAmaxCalculationPrimitive.impl(
+                x,
+                amax_scope=amax_scope,
+                transpose_batch_sequence=transpose_batch_sequence,
+                rht_matrix_random_sign_mask_t=rht_matrix_random_sign_mask_t,
+                produce_regular_amax=produce_regular_amax,
+                flatten_axis=flatten_axis,
+            )
+            amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
+                amax, x_spec, transpose_batch_sequence, mesh
+            )
+            post_rht_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP(
+                post_rht_amax, x_spec, transpose_batch_sequence, mesh
+            )
+
+            return amax, post_rht_amax
+
+        arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos)
+        return mesh, sharded_impl, out_shardings, arg_shardings
+
+    @staticmethod
+    def shardy_sharding_rule(
+        amax_scope,
+        transpose_batch_sequence,
+        rht_matrix_random_sign_mask_t,
+        produce_regular_amax,
+        flatten_axis,
+        mesh,
+        value_types,
+        result_types,
+    ):
+        """
+        amax calculation shardy_sharding_rule
+        """
+        del (
+            amax_scope,
+            transpose_batch_sequence,
+            rht_matrix_random_sign_mask_t,
+            produce_regular_amax,
+            flatten_axis,
+            mesh,
+            result_types,
+        )
+        prefix = "RHTAmaxCal"
+        input_spec = tuple(f"{prefix}_{i}" for i in range(len(value_types[0].shape)))
+        output_amax_spec = (f"{prefix}_amax",)
+        output_post_rht_amax_spec = (f"{prefix}_post_rht_amax",)
+        return SdyShardingRule((input_spec,), (output_amax_spec, 
output_post_rht_amax_spec)) + + +register_primitive(RHTAmaxCalculationPrimitive) + + +def calculate_amax(x: jnp.ndarray, amax_scope: AmaxScope, transpose_batch_sequence: bool): + """ + Compute the maximum absolute value (amax) of the input tensor. + """ + assert AmaxCalculationPrimitive.outer_primitive is not None + return AmaxCalculationPrimitive.outer_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + + +def calculate_post_rht_amax( + x: jnp.ndarray, + amax_scope: AmaxScope, + transpose_batch_sequence: bool, + produce_regular_amax: bool, + flatten_axis: int, +): + """Compute the post-Random Hadamard Transform (RHT) amax of the input tensor, and optionally the regular amax. + + Args: + x: Input tensor. + amax_scope: The scope for amax reduction (local, TPSP, or FSDP). + transpose_batch_sequence: Whether the input tensor has its batch and sequence dimensions transposed. + produce_regular_amax: Whether to compute and return the regular amax alongside the post-RHT amax. + flatten_axis: The axis at which to flatten the input tensor before applying RHT. + Returns: + A tuple containing: + - The regular amax if `produce_regular_amax` is True, otherwise None. + - The post-RHT amax. 
+ """ + amax, post_rht_amax = RHTAmaxCalculationPrimitive.outer_primitive.bind( + x, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + rht_matrix_random_sign_mask_t=get_sign_from_vector(get_wgrad_sign_vector()), + produce_regular_amax=produce_regular_amax, + flatten_axis=flatten_axis, + ) + + if produce_regular_amax: + return amax, post_rht_amax + return None, post_rht_amax diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 45d3d8b59..ba6d01a9f 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,12 +9,12 @@ import warnings from dataclasses import dataclass, replace from functools import partial, reduce -from typing import Optional, Tuple from packaging import version +from typing import Optional, Tuple import jax import jax.numpy as jnp -from jax import dtypes, lax +from jax import dtypes, lax, ffi from jax.sharding import PartitionSpec, NamedSharding if version.parse(jax.__version__) >= version.parse("0.5.0"): from jax.experimental.custom_partitioning import SdyShardingRule @@ -53,12 +53,6 @@ ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - - __all__ = [ "FusedAttnHelper", "fused_attn_fwd", @@ -1818,6 +1812,9 @@ def partition(config, mesh, arg_infos, result_infos): ) arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[4] = seed_sharding + # Ensure segment_pos gets same sharding as ID. 
+ arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -1854,7 +1851,7 @@ def ring_attn_fwd_impl( # RNG shape should be the shared shape. This is unused for ring attention as we do not # support dropout currently. - rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:]) rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): @@ -2025,7 +2022,13 @@ def partition(config, mesh, arg_infos, result_infos): dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) helper = _FusedAttnCPWithP2PHelper(mesh, config) @@ -2299,6 +2302,9 @@ def partition(config, mesh, arg_infos, result_infos): ) arg_shardings = [arg_i.sharding for arg_i in arg_infos] arg_shardings[4] = seed_sharding + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) @@ -2340,7 +2346,7 @@ def fwd_impl( # RNG shape should be the shared shape. This is unused for ring attention as we do not # support dropout currently. 
- rng_state_shape = (result_infos[2].shape[0] // mesh.size, *result_infos[2].shape[1:]) + rng_state_shape = (seed.shape[0], *result_infos[2].shape[1:]) rng_state = jnp.zeros(rng_state_shape).astype(result_infos[2].dtype) def scan_kv_block(idx, carry): @@ -2437,7 +2443,11 @@ def partition(config, mesh, arg_infos, result_infos): if not is_context_parallel: return FusedAttnBwdPrimitive.partition(config, mesh, arg_infos, result_infos) - arg_shardings = tuple(arg.sharding for arg in arg_infos) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + # Ensure segment_pos gets same sharding as ID. + arg_shardings[-1] = arg_shardings[-3] + arg_shardings[-2] = arg_shardings[-4] + arg_shardings = tuple(arg_shardings) # dq, dk, dv, dbias sharding = q, k, v, bias sharding out_shardings = tuple(arg.sharding for arg in arg_infos[:4]) @@ -2773,10 +2783,13 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) - if 100 in get_all_device_compute_capability(): + # TODO(KshitijLakhani): Add a check for cuDNN version when determinism does get supported on + # sm100+ + compute_capabilities = get_all_device_compute_capability() + if any(x >= 100 for x in compute_capabilities) and not is_hip_extension(): assert not ( attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 - ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + ), "For sm100+, bprop kernel support for dropout + determinism (bias) is not supported" fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 92c09bb68..176e0eadc 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -9,23 +9,17 @@ import warnings from abc import ABCMeta, abstractmethod from functools import partial -from packaging import version from jax.extend import core from jax.interpreters import xla, 
mlir from jax.experimental.custom_partitioning import custom_partitioning from jax._src.interpreters import batching from jax._src import dispatch +from jax import ffi from .misc import is_hip_extension -import jax import transformer_engine_jax -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - class BasePrimitive(metaclass=ABCMeta): """ @@ -182,7 +176,7 @@ def shardy_sharding_rule(*args): _primitive_registry = {} -def register_primitive(cls): +def register_primitive(cls, outer_only=False): """ Register a JAX primitive and add it to the internal registry. """ @@ -195,13 +189,14 @@ def register_primitive(cls): def name_of_wrapper_p(): return cls.name + "_wrapper" - inner_p = core.Primitive(cls.name) - dispatch.prim_requires_devices_during_lowering.add(inner_p) - inner_p.multiple_results = cls.multiple_results - inner_p.def_impl(partial(xla.apply_primitive, inner_p)) - inner_p.def_abstract_eval(cls.abstract) - mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") - cls.inner_primitive = inner_p + if not outer_only: + inner_p = core.Primitive(cls.name) + dispatch.prim_requires_devices_during_lowering.add(inner_p) + inner_p.multiple_results = cls.multiple_results + inner_p.def_impl(partial(xla.apply_primitive, inner_p)) + inner_p.def_abstract_eval(cls.abstract) + mlir.register_lowering(inner_p, cls.lowering, platform="rocm" if is_hip_extension() else "cuda") + cls.inner_primitive = inner_p outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) @@ -210,16 +205,11 @@ def name_of_wrapper_p(): outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) - if version.parse(jax.__version__) >= version.parse("0.5.0"): - 
outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, - partition=cls.partition, - sharding_rule=cls.shardy_sharding_rule, - ) - else: - outer_p_lower.def_partition( - infer_sharding_from_operands=cls.infer_sharding_from_operands, partition=cls.partition - ) + outer_p_lower.def_partition( + infer_sharding_from_operands=cls.infer_sharding_from_operands, + partition=cls.partition, + sharding_rule=cls.shardy_sharding_rule, + ) mlir.register_lowering( outer_p, mlir.lower_fun(outer_p_lower, multiple_results=cls.multiple_results) ) @@ -234,7 +224,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F """ Helper function to manage primitive states by name without modifying environment variables. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. - This helper is used in the get_quantize_config().initialize() methods. + This helper is used in the get_quantize_config_with_recipe().initialize() methods. Args: enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. 
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ba581c66..2daecedfa 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -8,8 +8,10 @@ import math import operator from collections.abc import Iterable -from typing import Tuple, Sequence, Union +from dataclasses import dataclass from functools import partial, reduce +from typing import Tuple, Sequence, Union +from enum import Enum import warnings import jax @@ -18,8 +20,13 @@ from jax.sharding import NamedSharding, PartitionSpec from jax.experimental.custom_partitioning import SdyShardingRule -import transformer_engine_jax as tex -from transformer_engine_jax import get_num_compute_streams +from transformer_engine_jax import ( + get_num_compute_streams, + JAXX_Collective_Op, + get_device_compute_capability, + #initialize_cgemm_communicator, + #get_cgemm_num_max_streams, +) from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize @@ -30,24 +37,35 @@ AbstractBaseTensor, NoScaleTensor, ScaledTensor, + ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, ScalingMode, Quantizer, GroupedQuantizer, - get_quantize_config, QuantizerSet, QuantizeLayout, noop_quantizer_set, is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, + get_quantize_config_with_recipe, + get_global_quantize_recipe, +) +from .misc import get_padded_spec, is_all_reduce_in_float32 +from ..sharding import ( + global_mesh_resource, + tpsp_axis_size, + dp_or_fsdp_axis_size, ) -from ..sharding import global_mesh_resource -from .misc import get_padded_spec __all__ = [ + "CollectiveOp", + "CollectiveOpSet", + "collective_gemm_bootstrap", + "noop_collective_op_set", "gemm", + "grouped_gemm_copy_group_sizes", "grouped_gemm", "gemm_uses_jax_dot", "sanitize_dims", @@ -65,11 +83,11 @@ def get_cublas_workspace_size_bytes() -> None: """Return workspace size needed for current architecture""" 
if is_hip_extension(): """Return 64 MiB for gfx50x, 32 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) == 95: + if get_device_compute_capability(0) == 95: return 67_108_864 return 33_554_432 """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" - if tex.get_device_compute_capability(0) >= 90: + if get_device_compute_capability(0) >= 90: return 33_554_432 return 4_194_304 @@ -135,6 +153,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ need_lhs_colwise = lhs_is_transposed and ( lhs_quantizer.scaling_mode.is_1d_block_scaling() or not is_fp8_gemm_with_all_layouts_supported() + or lhs_quantizer.scaling_mode.is_nvfp4_scaling ) flatten_axis = max(lhs_cdims) + 1 if lhs_is_transposed else min(lhs_cdims) lhs_q = lhs_quantizer.quantize( @@ -150,6 +169,7 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ need_rhs_colwise = not rhs_is_transposed and ( rhs_quantizer.scaling_mode.is_1d_block_scaling() or not is_fp8_gemm_with_all_layouts_supported() + or rhs_quantizer.scaling_mode.is_nvfp4_scaling ) flatten_axis = min(rhs_cdims) if rhs_is_transposed else max(rhs_cdims) + 1 rhs_q = rhs_quantizer.quantize( @@ -162,9 +182,180 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ assert not isinstance(lhs_q, ScaledTensor2x) assert not isinstance(rhs_q, ScaledTensor2x) + def has_rht_applied(q: AbstractBaseTensor) -> bool: + return isinstance(q, ScaledTensor1x) and q.has_rht_applied + + assert has_rht_applied(lhs_q) == has_rht_applied(rhs_q), ( + "With NVFP4_1D_SCALING, if one operand is quantized with RHT, the other must be quantized" + " with RHT as well. This is to ensure the RHT is applied to both and will cancel out in the" + " GEMM." 
+ ) + return lhs_q, rhs_q +def _get_nvfp4_tensor_scale_inv(amax): + DATA_DTYPE_MAX = jnp.finfo(jnp.float4_e2m1fn.dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(jnp.float8_e4m3fn.dtype).max.astype(jnp.float32) + return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) + + +def collective_gemm_bootstrap( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams=3, + compute_stream_priority=0, + communication_stream_priority=0, + num_sm_for_communication=2, + use_ce=True, + aggregate_all_gather=False, +): + """Initialize NCCL communicators for Collective GEMM operations. + + This function sets up the distributed communication infrastructure needed for + tensor parallel collective GEMM operations. It supports two main scenarios: + + 1. **Multi-device per process**: TP domain = single process + - Each process manages multiple GPUs (num_devices_per_process > 1) + - TP group consists of GPUs within the same process + - Example: 2 processes × 4 GPUs each = 8 total ranks, tp_size=4 + + 2. **Single device per process**: TP domain spans multiple processes + - Each process manages one GPU (num_devices_per_process = 1) + - TP group spans across multiple processes + - Example: 8 processes × 1 GPU each = 8 total ranks, tp_size=4 + + Args: + num_total_devices (int): Total number of ranks across all processes. + Must be divisible by num_devices_per_process. + num_devices_per_process (int): Number of GPUs per process. + - For multi-device: equals tp_size (e.g., 4 GPUs per process) + - For single-device: equals 1 (1 GPU per process) + process_id (int): Process identifier (0-based). + Must be in range [0, num_total_devices // num_devices_per_process). + tensor_parallel_size (int): Size of tensor parallel groups. + Must divide num_total_devices evenly. + num_max_streams (int, optional): Maximum number of CUDA streams for overlap. + Higher values enable more parallelism but use more GPU resources. Default: 3. 
+ compute_stream_priority (int, optional): Priority for GEMM computation streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + communication_stream_priority (int, optional): Priority for NCCL communication streams. + Lower values = higher priority. Range: 0 (highest) to 3 (lowest). Default: 0. + num_sm_for_communication (int, optional): Number of streaming multiprocessors + reserved for communication operations. Default: 2. + use_ce (bool, optional): Enable CUDA copy engines for memory transfers. + Can improve performance by offloading memory operations. Default: True. + aggregate_all_gather (bool, optional): Aggregate multiple small all-gather operations + into larger ones for better efficiency. Default: False. + + Raises: + AssertionError: If num_total_devices is not divisible by num_devices_per_process, + or if process_id is out of valid range. + AssertionError: If num_devices_per_process is not 1 (Temporary: only single device per process is supported for now) + RuntimeError: If NCCL initialization fails or if configuration + is invalid (e.g., insufficient GPUs). + + Example: + # Basic initialization (single device per process) + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4 + ) + + # Advanced configuration with custom performance settings + collective_gemm_bootstrap( + num_total_devices=8, + num_devices_per_process=1, + process_id=0, + tensor_parallel_size=4, + num_max_streams=5, # More parallelism + compute_stream_priority=1, # Lower compute priority + communication_stream_priority=0, # Higher comm priority + num_sm_for_communication=4, # More SMs for communication + use_ce=True, # Enable copy engines + aggregate_all_gather=True # Aggregate small operations + ) + + Note: + This function must be called after JAX distributed initialization + and before any collective GEMM operations. 
Each process should call + this function with its own unique process_id. + """ + if is_hip_extension(): + assert 0, "collective_gemm_bootstrap is not supported for ROCm yet." + assert ( + num_devices_per_process == 1 and jax.local_device_count() == 1 + ), "Only single device per process is supported at the moment!" + assert num_total_devices % num_devices_per_process == 0, ( + f"Invalid num_total_devices={num_total_devices}," + f" num_devices_per_process={num_devices_per_process}" + ) + assert 0 <= process_id < num_total_devices, f"Invalid process_id={process_id}" + initialize_cgemm_communicator( + num_total_devices, + num_devices_per_process, + process_id, + tensor_parallel_size, + num_max_streams, + compute_stream_priority, + communication_stream_priority, + num_sm_for_communication, + use_ce, + aggregate_all_gather, + ) + + +class CollectiveOp(Enum): + "Enum for Collective Type in Collective GEMM" + + NONE = JAXX_Collective_Op.NONE + ALL_GATHER = JAXX_Collective_Op.ALL_GATHER + REDUCE_SCATTER = JAXX_Collective_Op.REDUCE_SCATTER + + @property + def is_all_gather(self) -> bool: + """Check if AllGather""" + return self == CollectiveOp.ALL_GATHER + + @property + def is_reduce_scatter(self) -> bool: + """Check if ReduceScatter""" + return self == CollectiveOp.REDUCE_SCATTER + + @property + def is_none(self) -> bool: + """Check if None""" + return self == CollectiveOp.NONE + + +@dataclass(frozen=True) +class CollectiveOpSet: + """ + A set of CollectiveOp objects that provide complementary collective GEMM configurations for the Forward and Backward passes through Dense-layers. 
+ """ + + forward: CollectiveOp + backward: CollectiveOp + + @staticmethod + def create(forward_collective_op: CollectiveOp): + """Create a set of CollectiveOp for forward and backward passes""" + if forward_collective_op.is_all_gather: + backward_collective_op = CollectiveOp.REDUCE_SCATTER + elif forward_collective_op.is_reduce_scatter: + backward_collective_op = CollectiveOp.ALL_GATHER + else: + backward_collective_op = CollectiveOp.NONE + return CollectiveOpSet(forward=forward_collective_op, backward=backward_collective_op) + + +noop_collective_op_set = CollectiveOpSet.create(forward_collective_op=CollectiveOp.NONE) + + @partial(jax.jit, static_argnums=(1, 2)) def swizzled_scale(scale_inv, flatten_axis, is_colwise): "Swizzle scale_inv via JAX transpose ops" @@ -180,6 +371,28 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) +def get_lhs_axis_boundary(lhs_cdims, is_transposed): + """Get the axis boundary for the LHS operand.""" + return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims) + + +def get_rhs_axis_boundary(rhs_cdims, is_transposed): + """Get the axis boundary for the RHS operand.""" + return min(rhs_cdims) if is_transposed else max(rhs_cdims) + 1 + + +def assert_cublas_requirements(scaling_mode, contracting_size, tensor_name): + """Assert that the given tensor shape and layout meet the requirements for cuBLAS GEMM.""" + if scaling_mode != ScalingMode.NO_SCALING: + # Requirements from https://docs.nvidia.com/cuda/cublas/#tensor-core-usage + alignment = 32 if scaling_mode.is_nvfp4_scaling else 16 + + assert contracting_size % alignment == 0, ( + f"cuBLAS GEMM {tensor_name} tensor's contracting dimension must be a multiple of" + f" {alignment} when using quantized inputs. 
Got contracting_size={contracting_size}" + ) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -187,7 +400,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12) + impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) inner_primitive = None outer_primitive = None @@ -199,6 +412,8 @@ def abstract( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -206,8 +421,12 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del use_split_accumulator + del use_split_accumulator, transpose_batch_sequence def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -242,7 +461,9 @@ def _dims_are_consecutive(dims): lhs_is_transposed, rhs_is_transposed = _get_gemm_layout(operand_ndims, contracting_dims) if scaling_mode != ScalingMode.NO_SCALING: - assert _compatible_fp8_gemm_dtypes(lhs.dtype, rhs.dtype), ( + assert scaling_mode.is_nvfp4_scaling or _compatible_fp8_gemm_dtypes( + lhs.dtype, rhs.dtype + ), ( "cuBLAS GEMM quantized operands have incompatible data types: " f"{lhs.dtype} x {rhs.dtype}." ) @@ -251,7 +472,7 @@ def _dims_are_consecutive(dims): ), "Quantized cuBLAS GEMM requires inverse scaling factors for both operands." if ( scaling_mode != ScalingMode.MXFP8_1D_SCALING - and not tex.is_non_nt_fp8_gemm_supported() + and not is_fp8_gemm_with_all_layouts_supported() ): assert not lhs_is_transposed and rhs_is_transposed, ( "cuBLAS FP8 GEMM on devices with compute capability < 10.0 (Hopper) " @@ -276,24 +497,33 @@ def _dims_are_consecutive(dims): out_shape = (*lhs_non_contracting_shape, *rhs_non_contracting_shape) output = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) + # Adjust output shape for comm+GEMM overlap + if not collective_op.is_none and not is_outer: # Inner abstract + assert sequence_dim == 1, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + overlap_out_shape = list(out_shape).copy() + if collective_op.is_all_gather: + overlap_out_shape[1] *= tpsp_axis_size() + else: # RS + overlap_out_shape[sequence_dim] = ( + overlap_out_shape[sequence_dim] // tpsp_axis_size() + ) + assert out_dtype == jnp.bfloat16, f"Unsupported out_dtype={out_dtype}" + output = jax.core.ShapedArray(shape=overlap_out_shape, dtype=out_dtype) + # Validate bias - bias_shape = (0,) - bias_dtype = out_dtype if fuse_bias: - expected_bias_size = reduce(operator.mul, rhs_non_contracting_shape) - if not grad: - assert bias.size == expected_bias_size, ( - "cuBLAS GEMM bias tensor has incorrect shape, " - f"expected ({expected_bias_size}, ) but found {bias.shape}." - ) - assert bias.dtype == out_dtype, ( - "cuBLAS GEMM bias tensor has incorrect data type, " - f"expected {bias_dtype} but found {bias.dtype}." - ) - bias_shape = bias.shape - else: - bias_shape = rhs_non_contracting_shape - bias_grad = jax.core.ShapedArray(shape=bias_shape, dtype=bias_dtype) + assert bias.shape == tuple(rhs_non_contracting_shape), ( + "cuBLAS GEMM bias tensor has incorrect shape, " + f"expected ({tuple(rhs_non_contracting_shape)}, ) but found {bias.shape}." + ) + assert bias.dtype == out_dtype, ( + "cuBLAS GEMM bias tensor has incorrect data type, " + f"expected {out_dtype} but found {bias.dtype}." + ) + # WAR: allocate dbias regardless of fuse_bias so that the sharding propagation works as we + # change the fuse_bias value in the sharded_impl + dbias_shape = bias.shape if grad else (0,) + bias_grad = jax.core.ShapedArray(shape=dbias_shape, dtype=bias.dtype) # Validate pre-GeLU pre_gelu_shape = (0,) @@ -313,11 +543,19 @@ def _dims_are_consecutive(dims): f"expected {pre_gelu_dtype} but found {gelu_input.dtype}." 
) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) + assert alpha.size == 1 and alpha.dtype == jnp.float32 + assert beta.size == 1 and beta.dtype == jnp.float32 # Declare cuBLAS workspace + workspace_size = get_cublas_workspace_size_bytes() + # NVFP4 swizzling happen in via nvte kernel instead of JAX transposes + if scaling_mode.is_nvfp4_scaling: + workspace_size += lhs_scale_inv.size + rhs_scale_inv.size + if not collective_op.is_none: + workspace_size *= get_cgemm_num_max_streams() # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. - workspace_size = get_cublas_workspace_size_bytes() + 256 + workspace_size += 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) return output, bias_grad, pre_gelu_out, workspace @@ -336,6 +574,8 @@ def lowering( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -343,8 +583,12 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): - del out_dtype + del out_dtype, transpose_batch_sequence, sequence_dim, is_outer lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -352,21 +596,45 @@ def lowering( (lhs_aval.ndim, rhs_aval.ndim), (lhs_cdims, rhs_cdims) ) - args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input) + lhs_axis_boundary = get_lhs_axis_boundary(lhs_cdims, lhs_transposed) + lhs_contracting_size = ( + reduce(operator.mul, lhs_aval.shape[lhs_axis_boundary:]) + if lhs_transposed + else reduce(operator.mul, lhs_aval.shape[:lhs_axis_boundary]) + ) + assert_cublas_requirements( + scaling_mode, + lhs_contracting_size, + "LHS", + ) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_cdims, rhs_transposed) + rhs_contracting_size = ( + reduce(operator.mul, 
rhs_aval.shape[:rhs_axis_boundary]) + if rhs_transposed + else reduce(operator.mul, rhs_aval.shape[rhs_axis_boundary:]) + ) + assert_cublas_requirements( + scaling_mode, + rhs_contracting_size, + "RHS", + ) + + args = (lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta) kwargs = { "scaling_mode": int(scaling_mode.value), - "lhs_axis_boundary": max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - "rhs_axis_boundary": min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, + "lhs_axis_boundary": get_lhs_axis_boundary(lhs_cdims, lhs_transposed), + "rhs_axis_boundary": get_rhs_axis_boundary(rhs_cdims, rhs_transposed), "lhs_transposed": lhs_transposed, "rhs_transposed": rhs_transposed, "fuse_bias": fuse_bias, "fuse_gelu": fuse_gelu, "grad": grad, "use_split_accumulator": use_split_accumulator, + "collective_op": int(collective_op.value), } operand_output_aliases = {} - if fuse_bias and not grad: + if grad: operand_output_aliases.update({4: 1}) # bias <-> bias_grad if fuse_gelu and grad: operand_output_aliases.update({5: 2}) # gelu_input <-> pre_gelu_out @@ -384,6 +652,8 @@ def impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -391,6 +661,10 @@ def impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): if scaling_mode.is_1d_block_scaling(): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) @@ -406,16 +680,47 @@ def impl( rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis ) + # Only perform JAX-based swizzle for MXFP8, NVFP4 swizzle will go though nvte kernel + if scaling_mode.is_mxfp8_scaling: lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) - outputs = GemmPrimitive.inner_primitive.bind( + # Alter lhs blocks so 
that CGEMM RS outputs correctly + if ( + collective_op.is_reduce_scatter + and not transpose_batch_sequence + and not is_outer + and not lhs.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. Got sequence_dim={sequence_dim}" + original_shape = lhs.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = lhs.reshape( + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + tpsp_axis_size(), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(2, 0, 1, 3, *range(4, reshaped.ndim)) + lhs = reordered.reshape(original_shape) + + (output, bias_grad, pre_gelu_out, _) = GemmPrimitive.inner_primitive.bind( lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, @@ -423,8 +728,39 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, ) - return outputs[:-1] # discard workspace array + # Alter output blocks for CGEMM AG + if ( + collective_op.is_all_gather + and not transpose_batch_sequence + and not is_outer + and not output.shape[0] == 1 + ): + assert sequence_dim == 1, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + original_shape = output.shape + assert original_shape[0] % dp_or_fsdp_axis_size() == 0 or original_shape[0] == 1, ( + f"Original_shape[0]={original_shape[0]} is not divisible by" + f" dp_or_fsdp_axis_size()={dp_or_fsdp_axis_size()}" + ) + assert original_shape[1] % tpsp_axis_size() == 0 or original_shape[1] == 1, ( + f"Original_shape[1]={original_shape[1]} is not divisible by" + f" tpsp_axis_size()={tpsp_axis_size()}" + ) + reshaped = output.reshape( + tpsp_axis_size(), + dp_or_fsdp_axis_size(), + int(original_shape[0] / dp_or_fsdp_axis_size()), + int(original_shape[1] / tpsp_axis_size()), + *original_shape[2:], + ) + reordered = reshaped.transpose(1, 2, 0, 3, *range(4, reshaped.ndim)) + output = reordered.reshape(original_shape) + + return [output, bias_grad, pre_gelu_out] @staticmethod def outer_impl( @@ -434,6 +770,8 @@ def outer_impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -441,6 +779,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ): return GemmPrimitive.impl( lhs, @@ -449,6 +791,8 @@ def outer_impl( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype, contracting_dims, scaling_mode, @@ -456,6 +800,10 @@ def outer_impl( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, ) @staticmethod @@ -469,7 +817,12 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + collective_op, + transpose_batch_sequence, + sequence_dim, + is_outer, ): + del transpose_batch_sequence, sequence_dim, is_outer assert GemmPrimitive.outer_primitive is not None lhs_bdims, _, rhs_bdims, *_ = batch_dims @@ -497,6 +850,10 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + collective_op=collective_op, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=sequence_dim, + is_outer=is_outer, 
), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -505,6 +862,8 @@ def batcher( def _parse_operand_output_specs( arg_infos, contracting_dims, + transpose_batch_sequence, + collective_op, ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) @@ -512,14 +871,12 @@ def _parse_operand_output_specs( # Ensure that tensor sequence parallelism is not used via setting tp_resource if gsr.tp_resource is not None: - for i in range(len(lhs_specs) - 1): - if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource: - warnings.warn( - "Tensor sequence parallelism is detected as" - f" tp_resource='{gsr.tp_resource}' appears twice consecutively in" - f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for" - " tensor sequence parallelism to avoid potential issues." - ) + if gsr.tp_resource in lhs_specs: + warnings.warn( + "Tensor sequence parallelism is detected as tp_resource='{gsr.tp_resource}'" + " appears in lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource" + " for tensor sequence parallelism to avoid potential issues." + ) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) @@ -541,10 +898,43 @@ def _parse_operand_output_specs( assert reduce_spec is None, "Multiple reduce dimension is detected!" reduce_spec = l + sequence_dim = None + + # Find sequence dimension in lhs_specs if tensor sequence parallel is enabled + # We only do CollectiveGemm AG on the x or dY thus they always the LHS and have sequence dim + if collective_op.is_all_gather: + try: + tpsp_idx = lhs_specs.index(gsr.tpsp_resource) + except ValueError as exc: + raise ValueError( + f"tpsp_resource '{gsr.tpsp_resource}' is not found in lhs_specs: {lhs_specs}." + " Please check your sharding configuration." 
+ ) from exc + sequence_dim = tpsp_idx + assert (sequence_dim == 1) ^ transpose_batch_sequence, ( + "CollectiveGEMM supports only (sequence_dim=1 and transpose_batch_sequence=False)" + " or (sequence_dim=0 and transpose_batch_sequence=True). Received:" + f" sequence_dim={sequence_dim}," + f" transpose_batch_sequence={transpose_batch_sequence}." + ) + + elif collective_op.is_reduce_scatter: + assert reduce_spec == gsr.tpsp_resource, ( + "Only CollectiveGemm RS with the Reduction over the TPSP axis is supported! Got" + f" reduce_spec={reduce_spec}, tpsp_resource={gsr.tpsp_resource}" + ) + sequence_dim = int(not transpose_batch_sequence) + if reduce_spec is not None: # Other non-reduce cdims (if exists) need to be unsharded lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) - rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) + # Only do AG Sequence dim if not Overlap + if collective_op.is_all_gather: + rhs_cspecs = tuple( + s if s in (reduce_spec, gsr.tpsp_resource) else None for s in rhs_cspecs + ) + else: + rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. @@ -564,13 +954,31 @@ def _parse_operand_output_specs( for spec in rhs_non_cspecs ) - # Non-contracting dims of LHS to be gathered along the SP axis. - # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for - # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. - lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) + # Only do AG Sequence dim if not Overlap + if not collective_op.is_all_gather: + # Non-contracting dims of LHS to be gathered along the SP axis. 
+ # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for + # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. + lhs_non_cspecs = tuple( + None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs + ) out_specs = lhs_non_cspecs + rhs_non_cspecs + # Only do AG Sequence dim if not Overlap RS + if collective_op.is_all_gather: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = out_specs[:sequence_dim] + (None,) + out_specs[sequence_dim + 1 :] + elif collective_op.is_reduce_scatter: + assert sequence_dim <= len( + lhs_non_cspecs + ), f"Sequence dim {sequence_dim} is out of bounds for lhs_non_cspecs: {lhs_non_cspecs}" + out_specs = ( + out_specs[:sequence_dim] + (gsr.tpsp_resource,) + out_specs[sequence_dim + 1 :] + ) + # specs = merge(cspecs, non_cspecs) lhs_specs, rhs_specs = map( lambda cdims, cspecs, non_cspecs: ( @@ -585,10 +993,14 @@ def _parse_operand_output_specs( bias_specs = tuple(list(rhs_non_cspecs).copy()) gelu_specs = tuple(list(out_specs).copy()) + if not collective_op.is_none: + assert sequence_dim >= 0, f"Invalid sequence_dim. 
Got sequence_dim={sequence_dim}" + return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), reduce_spec, + sequence_dim, ) @staticmethod @@ -600,6 +1012,10 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, @@ -607,17 +1023,21 @@ def infer_sharding_from_operands( del ( out_dtype, scaling_mode, - grad, + use_split_accumulator, + result_infos, + is_outer, + sequence_dim, ) - del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( + GemmPrimitive._parse_operand_output_specs( + arg_infos, contracting_dims, transpose_batch_sequence, collective_op + ) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: + # Discard dbias gradient spec if there is no bias and grad fusion + if not (fuse_bias and grad): dbias_specs = (None,) dbias_sharding = NamedSharding(mesh, PartitionSpec(*dbias_specs)) @@ -637,20 +1057,29 @@ def partition( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, arg_infos, result_infos, ): - del result_infos + del result_infos, is_outer, sequence_dim ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), reduce_spec, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) + inferred_sequence_dim, + ) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + transpose_batch_sequence, + collective_op, + ) - # Assemble argument shardings - # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. 
+ # Block scale inverses match their operands, but tensor scale inverses are unsharded. none_sharding = NamedSharding(mesh, PartitionSpec(None)) lhs_sharding = NamedSharding(mesh, PartitionSpec(*lhs_specs)) rhs_sharding = NamedSharding(mesh, PartitionSpec(*rhs_specs)) @@ -671,11 +1100,14 @@ def partition( gelu_input_specs = (None,) arg_shardings += (NamedSharding(mesh, PartitionSpec(*gelu_input_specs)),) + # Alpha, beta + arg_shardings += (none_sharding, none_sharding) + # Assemble output shardings out_shardings = [NamedSharding(mesh, PartitionSpec(*out_specs))] - # Discard bias gradient spec if there is no bias fusion - if not fuse_bias: + # Discard bias gradient spec if there is no bias and grad fusion + if not (fuse_bias and grad): dbias_specs = (None,) out_shardings.append(NamedSharding(mesh, PartitionSpec(*dbias_specs))) @@ -684,7 +1116,9 @@ def partition( pre_gelu_specs = (None,) out_shardings.append(NamedSharding(mesh, PartitionSpec(*pre_gelu_specs))) - def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): + def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input, alpha, beta): + # We should not fuse bias in the output reduction case + sharded_fuse_bias = fuse_bias and reduce_spec is None outputs = GemmPrimitive.impl( lhs, lhs_scale_inv, @@ -692,18 +1126,32 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=contracting_dims, scaling_mode=scaling_mode, - fuse_bias=fuse_bias, + fuse_bias=sharded_fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=inferred_sequence_dim, + is_outer=False, + collective_op=collective_op, ) - # All-Reduce GEMM output if reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + if not collective_op.is_reduce_scatter: + if is_all_reduce_in_float32(): # 
For unittest only + outputs[0] = jax.lax.psum( + outputs[0].astype(jnp.float32), reduce_spec + ).astype(out_dtype) + else: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + + if fuse_bias: # TODO(Phuong): rename fuse_bias to has_bias + outputs[0] += bias return outputs @@ -718,20 +1166,24 @@ def shardy_sharding_rule( fuse_gelu, grad, use_split_accumulator, + transpose_batch_sequence, + sequence_dim, + is_outer, + collective_op, mesh, operand_types, result_types, ): - del out_dtype, grad, use_split_accumulator - del mesh, result_types + del out_dtype, use_split_accumulator + del mesh, result_types, transpose_batch_sequence, sequence_dim, is_outer - prefix = "GemmPrimitive_" + if not collective_op.is_none: + raise NotImplementedError( + "CollectiveGEMM with Shardy propagation is not supported yet! Please turn off" + " Shardy by exporting env var JAX_USE_SHARDY_PARTITIONER=false" + ) - warnings.warn( - "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," - " please turn off Shardy by exporting the environment variable" - " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems." - ) + prefix = "Gemm_" def _generate_operand_rules(name, ndim, cdims): specs = [] @@ -759,19 +1211,17 @@ def _generate_operand_rules(name, ndim, cdims): lhs_scale_specs = ("…1",) rhs_scale_specs = ("…2",) if scaling_mode.is_1d_block_scaling(): - # Shardy rules for MXFP8 scales cannot be related to the operands because of the - # global-unpadding and local-padding workflow. This can potentially insert expensive - # re-shards in the partition call later if the scales are not already sharded correctly. 
- lhs_scale_specs, rhs_scale_specs = map( - lambda specs: tuple(spec.replace(prefix, prefix + "scale_inv_") for spec in specs), - (lhs_specs, rhs_specs), - ) + lhs_scale_specs = lhs_specs + rhs_scale_specs = rhs_specs lhs_non_cspec = tuple(lhs_specs[i] for i in range(operand_ndims[0]) if i not in lhs_cdims) rhs_non_cspec = tuple(rhs_specs[i] for i in range(operand_ndims[1]) if i not in rhs_cdims) out_spec = (*lhs_non_cspec, *rhs_non_cspec) bias_spec = rhs_non_cspec if fuse_bias else ("…4",) gelu_spec = out_spec if fuse_gelu else ("…5",) + alpha_spec = ("_6",) + beta_spec = ("_7",) + dbias_spec = bias_spec if grad else ("…8") return SdyShardingRule( operand_mappings=( @@ -781,10 +1231,12 @@ def _generate_operand_rules(name, ndim, cdims): rhs_scale_specs, bias_spec, gelu_spec, + alpha_spec, + beta_spec, ), result_mappings=( out_spec, - bias_spec, + dbias_spec, gelu_spec, ), ) @@ -809,21 +1261,39 @@ def _te_gemm( fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, - use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, + use_split_accumulator: bool = None, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, ) -> Tuple[jax.Array, ...]: + if grad or fuse_gelu: + warnings.warn( + "GEMM + fused grad or fused gelu is not well tested and will be deprecated in the" + " future", + DeprecationWarning, + ) + + if use_split_accumulator is None: + # TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also + # use context of the GEMM type so we can decide between fprop, dgrad, and wgrad + use_split_accumulator = get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).FP8_2X_ACC_FPROP + # Prepare non-quantized GEMM operands lhs_data = lhs rhs_data = rhs lhs_scale_inv = jnp.empty(0, dtype=jnp.float32) rhs_scale_inv = jnp.empty(0, dtype=jnp.float32) scaling_mode = ScalingMode.NO_SCALING + lhs_is_transposed, rhs_is_transposed = 
_get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) # Quantize operands (if necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) + lhs_amax = rhs_amax = None # Extract GEMM custom op inputs from quantized operands if isinstance(lhs_q, ScaledTensor): assert isinstance(rhs_q, ScaledTensor) or rhs_quantizer is not None, ( @@ -838,6 +1308,7 @@ def _te_gemm( lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) + lhs_amax = lhs_q.amax if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -847,7 +1318,11 @@ def _te_gemm( if isinstance(rhs_q, ScaledTensor2x): # Choose the quantization of the contracting dimension(s) rhs_q = rhs_q.get_rowwise_tensor() if rhs_is_transposed else rhs_q.get_colwise_tensor() - assert rhs_q.scaling_mode == lhs_q.scaling_mode, ( + assert ( + rhs_q.scaling_mode == lhs_q.scaling_mode + or rhs_q.scaling_mode.is_nvfp4_scaling + and lhs_q.scaling_mode.is_nvfp4_scaling + ), ( "cuBLAS GEMM quantized operands have mismatched scaling types, " f"LHS:{lhs_q.scaling_mode} x RHS:{rhs_q.scaling_mode}." 
) @@ -855,6 +1330,15 @@ def _te_gemm( rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) + rhs_amax = rhs_q.amax + + alpha = jnp.ones((1,), jnp.float32) + beta = jnp.zeros((1,), jnp.float32) + if scaling_mode.is_nvfp4_scaling: + assert lhs_amax is not None and rhs_amax is not None + lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs_amax) + rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs_amax) + alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -870,6 +1354,8 @@ def _te_gemm( rhs_scale_inv, bias, gelu_input, + alpha, + beta, out_dtype=out_dtype, contracting_dims=(lhs_cdims, rhs_cdims), scaling_mode=scaling_mode, @@ -877,9 +1363,70 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + transpose_batch_sequence=transpose_batch_sequence, + sequence_dim=-1, # Dummy value and will be set in the primitive + is_outer=True, + collective_op=collective_op, ) +class GroupedGemmCopySizesPrimitive(BasePrimitive): + """ + Primitive for async copying group sizes from device to host + """ + + name = "te_grouped_gemm_d2h_group_sizes_ffi" + multiple_results = False + impl_static_args = (1,) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract( + group_sizes_aval, + *, + num_gemms, + ): + del num_gemms + out_aval = group_sizes_aval + return out_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + out = GroupedGemmCopySizesPrimitive.abstract(*args, **kwargs) + return out + + @staticmethod + def lowering( + ctx, + group_sizes, + num_gemms, + ): + return jax.ffi.ffi_lowering( + GroupedGemmCopySizesPrimitive.name, + operand_output_aliases={0: 0}, # Mark num_gemms as the output + )( + ctx, + group_sizes, + num_gemms=num_gemms, + ) + + @staticmethod + def impl( + group_sizes, + num_gemms, 
+ ): + assert GroupedGemmCopySizesPrimitive.inner_primitive is not None + out = GroupedGemmCopySizesPrimitive.inner_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + +register_primitive(GroupedGemmCopySizesPrimitive) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM @@ -887,7 +1434,7 @@ class GroupedGemmPrimitive(BasePrimitive): name = "te_grouped_gemm_ffi" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (7, 8, 9, 10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -910,6 +1457,7 @@ def abstract( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): """ Grouped GEMM operation. @@ -937,7 +1485,7 @@ def abstract( A jnp.ndarray containing the result of the grouped GEMM operation """ del lhs_data_aval, rhs_data_aval, bias_aval, group_offset_aval - del K, lhs_is_trans, rhs_is_trans, has_bias + del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes # TODO(Phuong): move some shape checks from Cpp to here workspace_size = get_cublas_workspace_size_bytes() * num_cublas_streams workspace_alignment_padding = 256 @@ -984,6 +1532,7 @@ def lowering( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): del out_dtype return jax.ffi.ffi_lowering(GroupedGemmPrimitive.name)( @@ -997,6 +1546,7 @@ def lowering( scaling_mode=scaling_mode.value, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @staticmethod @@ -1017,6 +1567,7 @@ def impl( out_dtype, has_bias, is_grouped_dense_wgrad, + use_async_d2h_group_sizes, ): assert GroupedGemmPrimitive.inner_primitive is not None (out, _) = GroupedGemmPrimitive.inner_primitive.bind( @@ -1036,6 +1587,7 @@ def impl( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return (out,) @@ -1094,15 
+1646,17 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): @partial(jax.jit, static_argnums=(2,)) -def _jax_gemm_mxfp8_1d( +def _jax_scaled_matmul( lhs: ScaledTensor, rhs: ScaledTensor, dim_nums: Tuple[Tuple[Sequence[int], Sequence[int]]] ): """ JAX GEMM for MXFP8 via scaled_matmul """ - assert ( - rhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING - ), "rhs does not have MXFP8 1D scaling mode" + assert rhs.scaling_mode in ( + ScalingMode.MXFP8_1D_SCALING, + ScalingMode.NVFP4_1D_SCALING, + ScalingMode.NVFP4_2D_SCALING, + ), f"rhs does not have MXFP8 or NVFP4 scaling mode, got rhs.scaling_mode={rhs.scaling_mode}" (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums @@ -1117,21 +1671,48 @@ def _jax_gemm_mxfp8_1d( f" {rhs.is_colwise}" ) + if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: + out_dtype = lhs.dq_dtype + assert ( + lhs.data_layout == "N" and rhs.data_layout == "N" + ), f"Got lhs.data_layout={lhs.data_layout}, rhs.data_layout={rhs.data_layout}" + else: + if lhs.data_layout == "T": + lhs_contract = transpose_dims( + lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis + ) + if rhs.data_layout == "T": + rhs_contract = transpose_dims( + rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis + ) + out_dtype = jnp.float32 + # Reshape + Transpose (if needed) # [..., M, K] -> [1, reduce(..., M), K] # [..., K, M] -> [1, reduce(..., M), K] - lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch)) - rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch)) - lhs_scale_3d = _shape_normalization(lhs.scale_inv, (lhs_contract, lhs_batch)) - rhs_scale_3d = _shape_normalization(rhs.scale_inv, (rhs_contract, rhs_batch)) + lhs_3d = _shape_normalization(lhs.data, (lhs_contract, lhs_batch), lhs.data_layout == "T") + rhs_3d = _shape_normalization(rhs.data, (rhs_contract, rhs_batch), rhs.data_layout == "T") + lhs_scale_3d = _shape_normalization( + lhs.scale_inv, (lhs_contract, lhs_batch), lhs.data_layout == "T" + ) + 
rhs_scale_3d = _shape_normalization( + rhs.scale_inv, (rhs_contract, rhs_batch), rhs.data_layout == "T" + ) # JAX scaled_matmul only supports NT now (TN-gemm) # * Expected shape: # * lhs_data (B, M, K) * rhs_data (B, N, K) # * lhs_scale (B, M, K_block) * rhs_scale (B, N, K_block) out_3d = jax.nn.scaled_matmul( - lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=lhs.dq_dtype + lhs_3d, rhs_3d, lhs_scale_3d, rhs_scale_3d, preferred_element_type=out_dtype ) + if lhs.scaling_mode.is_nvfp4_scaling: + assert lhs.amax is not None and rhs.amax is not None + lhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(lhs.amax) + rhs_tensor_scale_inv = _get_nvfp4_tensor_scale_inv(rhs.amax) + alpha = lhs_tensor_scale_inv * rhs_tensor_scale_inv + out_3d = (out_3d * alpha).astype(lhs.dq_dtype) + # Reshape [1, reduce(..., M), N] -> [..., M, N] lhs_remain_shape = tuple( lhs.data.shape[dim] for dim in range(len(lhs.data.shape)) if dim not in lhs_contract @@ -1140,6 +1721,7 @@ def _jax_gemm_mxfp8_1d( rhs.data.shape[dim] for dim in range(len(rhs.data.shape)) if dim not in rhs_contract ) out = out_3d.reshape(*lhs_remain_shape, *rhs_remain_shape) + return out @@ -1155,27 +1737,32 @@ def _jax_gemm( """ dim_nums = (contracting_dims, ((), ())) - def _jax_gemm_fp8_impl(lhs, rhs): + def _jax_gemm_impl(lhs, rhs): if lhs.scaling_mode.is_tensor_scaling(): assert ( rhs.scaling_mode == lhs.scaling_mode ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" + + # TODO(jberchtold): Rework GEMM API to provide the context here instead of relying on global state and also + # use context of the GEMM type so we can decide between fprop, dgrad, and wgrad + use_split_accumulator = get_quantize_config_with_recipe( + get_global_quantize_recipe() + ).FP8_2X_ACC_FPROP + precision = ( - jax.lax.Precision.HIGHEST - if get_quantize_config().FP8_2X_ACC_FPROP - else jax.lax.Precision.DEFAULT + jax.lax.Precision.HIGHEST if use_split_accumulator else jax.lax.Precision.DEFAULT ) 
return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) - if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) + if lhs.scaling_mode.is_1d_block_scaling: + return _jax_scaled_matmul(lhs, rhs, dim_nums) raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) if isinstance(lhs_q, ScaledTensor) and isinstance(rhs_q, ScaledTensor): - return _jax_gemm_fp8_impl(lhs_q, rhs_q) + return _jax_gemm_impl(lhs_q, rhs_q) if ( isinstance(lhs, jnp.ndarray) @@ -1194,6 +1781,8 @@ def gemm( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, + transpose_batch_sequence: bool = False, + collective_op: CollectiveOp = CollectiveOp.NONE, **kwargs, ) -> Tuple[jnp.ndarray, ...]: r"""General matrix multiplication with optional quantization. @@ -1227,8 +1816,11 @@ def gemm( TE's custom call to cuBLAS GEMM. use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only - supported with TE's custom call to cuBLAS GEMM. + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + transpose_batch_sequence: bool, default = False + Transpose the batch and sequence dimensions of the input tensor. + collective_op: CollectiveOp, default = CollectiveOp.NONE + Collective operation type for collective GEMM. 
Returns ------- @@ -1259,6 +1851,7 @@ def gemm( rhs_quantizer = quantizer_set.kernel # Fall back on a native JAX implementation when the custom call to cuBLAS GEMM is disabled + # TODO(Phuong): fuse_bias -> has_bias and has_bias = bias is not None fuse_bias = kwargs.get("fuse_bias", False) fuse_gelu = kwargs.get("fuse_gelu", False) if not GemmPrimitive.enabled(): @@ -1272,6 +1865,7 @@ def gemm( "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) + assert collective_op.is_none, "JAX GEMM does not support collective GEMM" return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) outputs = _te_gemm( @@ -1280,6 +1874,8 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op, **kwargs, ) @@ -1295,6 +1891,24 @@ def gemm( return clean_outputs +def grouped_gemm_copy_group_sizes( + group_sizes: jnp.ndarray, + num_gemms: int, +) -> jnp.ndarray: + """ + Async copy group sizes from device to host + + Args: + group_sizes: 1D array containing the sizes of each group + num_gemms: number of grouped gemm calls to be made + """ + out = GroupedGemmCopySizesPrimitive.outer_primitive.bind( + group_sizes, + num_gemms=num_gemms, + ) + return out + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], @@ -1305,6 +1919,7 @@ def grouped_gemm( preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + use_async_d2h_group_sizes: bool = False, ) -> jnp.ndarray: """ Grouped GEMM operation. 
@@ -1488,5 +2103,6 @@ def grouped_gemm( out_dtype=out_dtype, has_bias=has_bias, is_grouped_dense_wgrad=is_grouped_dense_wgrad, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) return out diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index e7464a6da..6c4be68ec 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -8,8 +8,6 @@ import os import functools from typing import Tuple -from importlib.metadata import version as get_pkg_version -from packaging.version import Version as PkgVersion import numpy as np @@ -78,7 +76,8 @@ def jax_dtype_to_te_dtype(jax_dtype): jnp.int64.dtype: TEDType.kInt64, get_jnp_float8_e4m3_type().dtype: TEDType.kFloat8E4M3, get_jnp_float8_e5m2_type().dtype: TEDType.kFloat8E5M2, - jnp.uint8.dtype: TEDType.kByte, + jnp.float8_e8m0fnu.dtype: TEDType.kFloat8E8M0, + jnp.float4_e2m1fn.dtype: TEDType.kFloat4E2M1, } if jax_dtype not in converter: @@ -157,16 +156,6 @@ def get_cudnn_version() -> Tuple[int, int, int]: return (major, minor, patch) -@functools.lru_cache(maxsize=None) -def jax_version_meet_requirement(version: str): - """ - Helper function checking if required JAX version is available - """ - jax_version = PkgVersion(get_pkg_version("jax")) - jax_version_required = PkgVersion(version) - return jax_version >= jax_version_required - - def get_xla_flag(flag: str, default=None, cast=str): """ Returns the value of a flag/option in XLA_FLAGS environment variable if present or returns the default value. @@ -224,7 +213,9 @@ def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quant break # _quantize_dbias_impl forcing 1x quantization for tensor scaling switches q_layout to ROWWISE, # but this fails when bias fusion is turned on with arch < 100. 
- force_1x_quantization = quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() + force_1x_quantization = ( + quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise + ) return ( (force_1x_quantization or quantizer.q_layout == QuantizeLayout.ROWWISE) and arch_l_100 @@ -246,7 +237,9 @@ def try_apply_delayed_scaling_2x_war(f, *args, quantizer=None, flatten_axis=-1, @return: the output of 'f' with the colwise output calculated """ should_apply_war = ( - quantizer is not None and quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x() + quantizer is not None + and quantizer.scaling_mode.is_tensor_scaling() + and quantizer.q_layout.is_rowwise_colwise ) if not should_apply_war: return None @@ -299,3 +292,11 @@ def duplicate_with_new_description(self, desc: str): Create a new NamedSharding with the same mesh and spec but with a new description. """ return NamedSharding(self.mesh, self.spec, desc=desc) + + +@functools.lru_cache(maxsize=1) +def is_all_reduce_in_float32(): + """ + Check if all-reduce is in float32 + """ + return os.getenv("NVTE_JAX_ALL_REDUCE_IN_FP32", "0") == "1" diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 89731e24a..e53d63625 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -8,14 +8,13 @@ import warnings import operator from functools import partial, cache, reduce -from typing import Optional, Union from packaging import version +from typing import Optional, Union import jax import jax.numpy as jnp -from jax import dtypes -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule +from jax import dtypes, ffi +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.interpreters.mlir import ir from jax.sharding import PartitionSpec from .misc import is_hip_extension @@ -32,8 +31,11 @@ NamedSharding, get_cudnn_version, ) -from .quantization import _quantize_dbias_impl -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from .quantization import quantize, AmaxScope +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp_tpsp, +) from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, @@ -42,11 +44,6 @@ ScalingMode, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "layernorm_fwd", @@ -102,7 +99,7 @@ class NormFwdPrimitive(BasePrimitive): name = "te_norm_forward_ffi" multiple_results = True - impl_static_args = (4, 5, 6, 7, 8, 9, 10, 11) + impl_static_args = (5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @@ -110,6 +107,7 @@ class NormFwdPrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, gamma_aval, beta_aval, *, @@ -118,17 +116,29 @@ def abstract( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ LayerNorm fwd inner primitive abstract """ + del amax_scope, 
transpose_batch_sequence + assert not output_amax_when_no_scaling or ( + scaling_mode == ScalingMode.NO_SCALING.value + and not is_norm_fwd_cudnn_enabled(scaling_mode) + ), ( + f"scaling_mode = {scaling_mode}," + f" use_cudnn_norm_fwd={is_norm_fwd_cudnn_enabled(scaling_mode)}" + ) x_dtype = dtypes.canonicalize_dtype(x_aval.dtype) assert x_dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 assert ( scaling_mode != ScalingMode.MXFP8_1D_SCALING.value @@ -142,6 +152,13 @@ def abstract( "Current tensor scaling is not supported for fused norm and quantization. Please do" " norm in higher-precision then quantize with current tensor scaling." ) + assert not ScalingMode(scaling_mode).is_nvfp4_scaling, ( + "NVFP4 block scaling is not yet supported for fused norm and quantization." + " Please do norm in higher-precision then quantize with current tensor scaling." + ) + assert ( + not quantize_layout.is_colwise_only + ), "Fused norm with colwise-only quantization is not supported." 
mu_rsigama_dtype = jnp.float32 @@ -159,7 +176,7 @@ def abstract( updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) - colwise_out_shape = x_aval.shape if is_2x else (1,) + colwise_out_shape = x_aval.shape if quantize_layout.has_colwise else (1,) colwise_out_aval = jax.core.ShapedArray(shape=colwise_out_shape, dtype=out_dtype) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -167,7 +184,7 @@ def abstract( ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer) scale_inv_aval = jax.core.ShapedArray(shape=rowwise_scale_inv_shape, dtype=scale_dtype) - colwise_scale_inv_shape = colwise_scale_inv_shape if is_2x else (1,) + colwise_scale_inv_shape = colwise_scale_inv_shape if quantize_layout.has_colwise else (1,) colwise_scale_inv_aval = jax.core.ShapedArray( shape=colwise_scale_inv_shape, dtype=scale_dtype ) @@ -183,7 +200,7 @@ def abstract( zero_centered_gamma, epsilon, get_forward_sm_margin(), - is_2x, + True, # is_training ) wkspace_aval = jax.core.ShapedArray( shape=wkspace_info[0], dtype=te_dtype_to_jax_dtype(wkspace_info[1]) @@ -230,6 +247,7 @@ def lowering( ctx, x, scale, + amax, gamma, beta, *, @@ -238,18 +256,22 @@ def lowering( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ LayerNorm fwd lowering rules """ - del out_dtype, scale_dtype, is_outer - x_aval, scale_aval, gamma_aval, beta_aval = ctx.avals_in + del out_dtype, scale_dtype, is_outer, amax_scope, transpose_batch_sequence + x_aval, scale_aval, amax_aval, gamma_aval, beta_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert amax_aval is None or amax_aval.dtype == jnp.float32 g_type = ir.RankedTensorType(gamma.type) g_shape = g_type.shape @@ -261,10 +283,14 @@ def lowering( assert g_shape == b_shape sm_margin = get_forward_sm_margin() - return 
ffi.ffi_lowering(NormFwdPrimitive.name)( + return ffi.ffi_lowering( + NormFwdPrimitive.name, + operand_output_aliases={2: 4}, # amax <-> updated_amax + )( ctx, x, scale, + amax, gamma, beta, norm_type=norm_type.value, @@ -272,13 +298,15 @@ def lowering( epsilon=epsilon, sm_margin=sm_margin, scaling_mode=scaling_mode.value, - is_2x=is_2x, + quantize_layout=quantize_layout.value.value, + output_amax_when_no_scaling=output_amax_when_no_scaling, ) @staticmethod def impl( x, scale, + amax, gamma, beta, norm_type, @@ -286,8 +314,11 @@ def impl( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ @@ -307,6 +338,7 @@ def impl( ) = NormFwdPrimitive.inner_primitive.bind( x, scale, + amax, gamma, beta, norm_type=norm_type, @@ -314,8 +346,11 @@ def impl( epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=False, ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( @@ -325,7 +360,7 @@ def impl( scale_inv = scale_inv.flatten()[: reduce(operator.mul, rowwise_scale_inv_shape, 1)].reshape( rowwise_scale_inv_shape ) - if is_2x: + if quantize_layout.has_colwise: colwise_scale_inv = colwise_scale_inv.flatten()[ : reduce(operator.mul, colwise_scale_inv_shape, 1) ].reshape(colwise_scale_inv_shape) @@ -349,18 +384,20 @@ def batcher( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, ): """ to describe batch rules for vmap """ - del is_outer check_valid_batch_dims(batch_dims) assert NormFwdPrimitive.outer_primitive is not None - x, scale, gamma, beta = batched_args - x_bdim, scale_bdim, _, _ = batch_dims + x, scale, amax, gamma, 
beta = batched_args + x_bdim, scale_bdim, _, _, _ = batch_dims out_bdims = ( x_bdim, # rowwise output @@ -373,8 +410,9 @@ def batcher( ) return ( NormFwdPrimitive.outer_primitive.bind( - scale, x, + scale, + amax, gamma, beta, norm_type=norm_type, @@ -382,8 +420,12 @@ def batcher( epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, + is_outer=is_outer, ), out_bdims, ) @@ -395,17 +437,21 @@ def infer_sharding_from_operands( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, result_infos, ): del zero_centered_gamma, epsilon, out_dtype, result_infos - del scale_dtype, is_outer + del scale_dtype, is_outer, amax_scope, transpose_batch_sequence, output_amax_when_no_scaling x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: warnings.warn( @@ -415,7 +461,7 @@ def infer_sharding_from_operands( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) @@ -425,9 +471,9 @@ def infer_sharding_from_operands( mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_inv_spec = amax_spec = (None,) + scale_inv_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec 
+ scale_inv_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec @@ -453,8 +499,11 @@ def partition( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, arg_infos, @@ -463,8 +512,9 @@ def partition( del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) scale_spec = get_padded_spec(arg_infos[1]) - g_spec = get_padded_spec(arg_infos[2]) - b_spec = get_padded_spec(arg_infos[3]) + amax_spec = get_padded_spec(arg_infos[2]) + g_spec = get_padded_spec(arg_infos[3]) + b_spec = get_padded_spec(arg_infos[4]) out_spec = (*x_spec[:-1], None) if x_spec[-1] is not None: @@ -485,7 +535,7 @@ def partition( ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.out") - colwise_out_spec = out_spec if is_2x else (None,) + colwise_out_spec = out_spec if quantize_layout.has_colwise else (None,) colwise_out_sharding = NamedSharding( mesh, PartitionSpec(*colwise_out_spec), desc="NormFwdPrimitive.colwise_out" ) @@ -495,9 +545,9 @@ def partition( mu_spec = x_spec[:-1] if norm_type == NVTE_Norm_Type.LayerNorm else (None,) mu_sharding = NamedSharding(mesh, PartitionSpec(*mu_spec), desc="NormFwdPrimitive.mu") - scale_inv_spec = amax_spec = (None,) + scale_inv_spec = (None,) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec + scale_inv_spec = scale_spec elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = out_spec @@ -509,10 +559,10 @@ def partition( arg_shardings = list(arg_i.sharding for arg_i in arg_infos) # Enforce no sharding of hidden dim for x, gamma and beta arg_shardings[0] = NamedSharding(mesh, PartitionSpec(*out_spec), desc="NormFwdPrimitive.x") - arg_shardings[2] = NamedSharding( + arg_shardings[3] = NamedSharding( mesh, PartitionSpec(*g_spec[:-1], None), desc="NormFwdPrimitive.gamma" ) - arg_shardings[3] = 
NamedSharding( + arg_shardings[4] = NamedSharding( mesh, PartitionSpec(*b_spec[:-1], None), desc="NormFwdPrimitive.beta" ) arg_shardings = tuple(arg_shardings) @@ -526,19 +576,20 @@ def partition( rsigma_sharding, ) - def sharded_impl(x, scale, gamma, beta): + def sharded_impl(x, scale, amax, gamma, beta): # expect tp and dp giving same shape, or tp being same shape as global ( local_x, local_colwise_x, local_scale_inv, local_colwise_scale_inv, - local_amax, + local_updated_amax, local_mu, local_rsigma, ) = NormFwdPrimitive.impl( x, scale, + amax, gamma, beta, norm_type=norm_type, @@ -546,14 +597,23 @@ def sharded_impl(x, scale, gamma, beta): epsilon=epsilon, out_dtype=out_dtype, scaling_mode=scaling_mode, - is_2x=is_2x, + quantize_layout=quantize_layout, scale_dtype=scale_dtype, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax, mesh) + global_updated_amax = all_reduce_max_along_all_axes_except_PP( + local_updated_amax, mesh + ) + elif scaling_mode == ScalingMode.NO_SCALING.value and output_amax_when_no_scaling: + global_updated_amax = amax_scope.all_reduce_amax_along_TPSP_and_FSDP( + local_updated_amax, x_spec, transpose_batch_sequence, mesh + ) else: - global_updated_amax = local_amax + global_updated_amax = local_updated_amax return ( local_x, @@ -574,8 +634,11 @@ def shardy_sharding_rule( epsilon, out_dtype, scaling_mode, - is_2x, + quantize_layout, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, value_types, @@ -588,34 +651,42 @@ def shardy_sharding_rule( epsilon, out_dtype, scale_dtype, + amax_scope, + transpose_batch_sequence, + output_amax_when_no_scaling, is_outer, mesh, result_types, ) - prefix = "NormFwdPrimitive_" + prefix = "NormFwd" scale_rules = 
ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), unique_var=prefix + "x", flatten_axis=-1 + value_types[0].shape, + unique_var=prefix, + flatten_axis=-1, + q_layout=quantize_layout, ) - x_axes = scale_rules.input_spec + input_spec = scale_rules.input_spec - out = x_axes - colwise_out = out if is_2x else (prefix + "out_colwise",) - rsigma = x_axes[:-1] - mu = (prefix + "mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma - amax = (prefix + "amax",) + rsigma = input_spec[:-1] + mu = (BATCHING + prefix + "_mu",) if norm_type == NVTE_Norm_Type.RMSNorm else rsigma + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) + gamma = (BATCHING + prefix + "_gamma",) + beta = (BATCHING + prefix + "_beta",) return SdyShardingRule( - (x_axes, ("…1",), ("…2",), ("…3",)), + (input_spec, scale, amax, gamma, beta), ( - out, - colwise_out, - scale_rules.rowwise_rule, - scale_rules.colwise_rule, + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, amax, mu, rsigma, ), + **scale_rules.factor_sizes, ) @@ -812,9 +883,9 @@ def sharded_impl(dz, x, mu, rsigma, gamma): norm_type=norm_type, zero_centered_gamma=zero_centered_gamma, ) - global_dgamma = all_reduce_sum_along_dp_fsdp(local_dgamma, mesh) + global_dgamma = all_reduce_sum_along_dp_fsdp_tpsp(local_dgamma, mesh) if norm_type == NVTE_Norm_Type.LayerNorm: - global_dbeta = all_reduce_sum_along_dp_fsdp(local_dbeta, mesh) + global_dbeta = all_reduce_sum_along_dp_fsdp_tpsp(local_dbeta, mesh) else: global_dbeta = local_dbeta return local_dx, global_dgamma, global_dbeta @@ -891,7 +962,10 @@ def layernorm_fwd( beta: jnp.ndarray, zero_centered_gamma: bool, epsilon: float, - quantizer: Optional[Quantizer], + quantizer: Optional[Quantizer] = None, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> tuple[Union[jnp.ndarray, 
ScaledTensor], jnp.ndarray, jnp.ndarray]: """Layer normalization forward pass with optional quantization. @@ -905,6 +979,8 @@ def layernorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -916,10 +992,19 @@ def layernorm_fwd( - Reciprocal of the standard deviation of the input tensor. Shape: (..., 1) """ if not NormFwdPrimitive.enabled(): - return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) + output, mu, rsigma = _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon) + if quantizer is not None: + output = quantize( + output, + quantizer, + flatten_axis=-1, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + return (output, mu, rsigma) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -927,10 +1012,12 @@ def layernorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) + amax = jnp.zeros((1,), dtype=jnp.float32) if quantizer is None: - output, _, _, _, _, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( + output, _, _, _, updated_amax, mu, rsigma = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.LayerNorm, @@ -938,23 +1025,45 @@ def layernorm_fwd( epsilon=epsilon, out_dtype=x.dtype, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + 
quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, + amax_scope=amax_scope, + transpose_batch_sequence=False, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) - return NoScaleTensor(data=output, amax=None), mu, rsigma + # cuDNN does not support amax output for non quantized output + updated_amax = ( + updated_amax + if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING) + else None + ) + return NoScaleTensor(data=output, amax=updated_amax), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, mu, rsigma = layernorm_fwd( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=False, + ) + out, _ = quantize( + out, quantizer, amax_scope=amax_scope, transpose_batch_sequence=transpose_batch_sequence ) - out, _ = _quantize_dbias_impl(out, quantizer) return out, mu, rsigma - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. 
out, mu, rsigma = layernorm_fwd( x=x, @@ -963,14 +1072,23 @@ def layernorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, + ) + out = quantize( + out, + quantizer=quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) return out, mu, rsigma - is_2x2x = quantizer.is_2x2x() - # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): - is_2x2x = False + # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose + q_layout = quantizer.q_layout + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -982,6 +1100,7 @@ def layernorm_fwd( ) = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.LayerNorm, @@ -989,14 +1108,16 @@ def layernorm_fwd( epsilon=epsilon, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=is_2x2x, + quantize_layout=q_layout, scale_dtype=quantizer.get_scale_dtype(), + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) quantizer.update(updated_amax) - # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1094,6 +1215,9 @@ def rmsnorm_fwd( zero_centered_gamma: bool, epsilon: float, quantizer: 
Optional[Quantizer], + amax_scope: AmaxScope = AmaxScope.TPSP, + transpose_batch_sequence: bool = False, + output_amax_when_no_scaling: bool = False, ) -> tuple[Union[jnp.ndarray, ScaledTensor], jnp.ndarray]: """Root mean square normalization forward pass with optional quantization. @@ -1105,6 +1229,8 @@ def rmsnorm_fwd( zero_centered_gamma: If True, gamma is zero-centered. epsilon: Small constant for numerical stability. quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -1117,10 +1243,19 @@ def rmsnorm_fwd( Shape: (..., 1) """ if not NormFwdPrimitive.enabled(): - return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) + output, rsigma = _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon) + if quantizer is not None: + output = quantize( + output, + quantizer, + flatten_axis=-1, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + return (output, rsigma) # TE/common does not support normalization with colwise only quantization yet - if quantizer is not None and quantizer.q_layout == QuantizeLayout.COLWISE: + if quantizer is not None and quantizer.q_layout.is_colwise_only: return _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer) scale = ( @@ -1128,12 +1263,14 @@ def rmsnorm_fwd( if isinstance(quantizer, DelayedScaleQuantizer) else jnp.ones((1,), dtype=jnp.float32) ) + amax = jnp.zeros((1,), dtype=jnp.float32) beta = jnp.ones((1,), dtype=jnp.float32) if quantizer is None: - output, _, _, _, _, _, rsigma = NormFwdPrimitive.outer_primitive.bind( + output, _, _, _, updated_amax, _, rsigma = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.RMSNorm, @@ -1141,21 +1278,47 @@ 
def rmsnorm_fwd( epsilon=epsilon, out_dtype=x.dtype, scaling_mode=ScalingMode.NO_SCALING.value, - is_2x=False, + quantize_layout=QuantizeLayout.ROWWISE, scale_dtype=jnp.float32, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) - return NoScaleTensor(data=output, amax=None), rsigma + # cuDNN does not support amax output for non quantized output + updated_amax = ( + updated_amax + if output_amax_when_no_scaling and not is_norm_fwd_cudnn_enabled(ScalingMode.NO_SCALING) + else None + ) + return NoScaleTensor(data=output, amax=updated_amax), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): - out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out.data, quantizer) + out, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=False, + ) + out = quantize( + out.data, + quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) return out, rsigma - if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + if ( + quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING + or quantizer.scaling_mode.is_nvfp4_scaling + ): # Current scaling does not support fused operations. Perform norm in higher precision then quantize after. 
out, rsigma = rmsnorm_fwd( x=x, @@ -1163,16 +1326,23 @@ def rmsnorm_fwd( zero_centered_gamma=zero_centered_gamma, epsilon=epsilon, quantizer=None, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=True, ) - out, _ = _quantize_dbias_impl( - out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + out = quantize( + out, + quantizer=quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out, rsigma - is_2x2x = quantizer.is_2x2x() - # TE/common normalization doesn't support 2x delayed scaling - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): - is_2x2x = False + # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose + q_layout = quantizer.q_layout + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): + q_layout = QuantizeLayout.ROWWISE + ( rowwise_casted_output, colwise_casted_output, @@ -1184,6 +1354,7 @@ def rmsnorm_fwd( ) = NormFwdPrimitive.outer_primitive.bind( x, scale, + amax, gamma, beta, norm_type=NVTE_Norm_Type.RMSNorm, @@ -1191,14 +1362,16 @@ def rmsnorm_fwd( epsilon=epsilon, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - is_2x=is_2x2x, + quantize_layout=q_layout, scale_dtype=quantizer.get_scale_dtype(), + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + output_amax_when_no_scaling=output_amax_when_no_scaling, is_outer=True, ) quantizer.update(updated_amax) - # TE/common Norm doesn't support 2x delayed scaling so do 1x then JAX transpose - if quantizer.is_2x2x() and quantizer.scaling_mode.is_tensor_scaling(): + if quantizer.q_layout.is_rowwise_colwise and quantizer.scaling_mode.is_tensor_scaling(): colwise_casted_output = jnp.transpose( rowwise_casted_output, (-1, *range(rowwise_casted_output.ndim - 1)) ) @@ -1290,6 +1463,8 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], + 
amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, ): """Common wrapper for normalization forward pass. @@ -1306,6 +1481,8 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. + transpose_batch_sequence: Indicate the sequence dimension. This only works when using current-scaling. Default is False. Returns: A tuple containing: @@ -1323,12 +1500,29 @@ def normalization_fwd( zero_centered_gamma is not supported if norm_type is 'rmsnorm'. """ if norm_type == "layernorm": - output, mu, rsigma = layernorm_fwd(x, gamma, beta, zero_centered_gamma, epsilon, quantizer) + output, mu, rsigma = layernorm_fwd( + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) elif norm_type == "rmsnorm": assert ( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" - output, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer) + output, rsigma = rmsnorm_fwd( + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) mu = None else: raise ValueError(f"{norm_type=} is not supported.") diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 78780ff9c..bd2176170 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -8,17 +8,17 @@ from functools import reduce from typing import Tuple, Optional, Union import math -from packaging import version + import jax import jax.numpy as jnp -from jax import dtypes -if version.parse(jax.__version__) >= 
version.parse("0.5.0"): - from jax.experimental.custom_partitioning import SdyShardingRule +from jax import dtypes, ffi +from jax.experimental.custom_partitioning import SdyShardingRule, BATCHING from jax.sharding import PartitionSpec import transformer_engine_jax +from .amax import AmaxScope, calculate_amax, calculate_post_rht_amax from .base import BasePrimitive, register_primitive from .misc import ( get_padded_spec, @@ -30,7 +30,11 @@ get_min_device_compute_capability, NamedSharding, ) -from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp +from ..sharding import ( + all_reduce_max_along_all_axes_except_PP, + all_reduce_sum_along_dp_fsdp, + get_num_devices_in_mesh, +) from ..quantize import ( ScaledTensor2x, ScaledTensor, @@ -42,13 +46,9 @@ ScalingMode, compute_scale_from_amax, NoScaleTensor, + get_rht_matrix, ) -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] @@ -61,14 +61,16 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): name = "te_dbias_quantize_ffi" multiple_results = True impl_static_args = ( - 3, - 4, - 5, - 6, - 7, - 8, - 9, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer + 6, # out_dtype + 7, # scaling_mode + 8, # q_layout + 9, # flatten_axis + 10, # scale_dtype + 11, # is_dbias + 12, # is_outer + 13, # stochastic_rounding + 14, # use_rht + ) inner_primitive = None outer_primitive = None @@ -77,6 +79,9 @@ def abstract( x_aval, scale_aval, amax_aval, + sr_rng_state_aval, + post_rht_amax_aval, + rht_matrix_aval, *, out_dtype, scaling_mode, @@ -85,6 +90,8 @@ def abstract( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p abstract @@ -93,21 +100,80 @@ def abstract( assert dtype in [jnp.float32, jnp.float16, 
jnp.bfloat16] out_shape = x_aval.shape assert scale_aval is None or scale_aval.dtype == jnp.float32 + if stochastic_rounding: + assert ScalingMode( + scaling_mode + ).is_nvfp4_scaling, "stochastic_rounding can only be used with NVFP4 scaling modes" + # JAX doesn't support 64-bit by default so use 4x uint32 instead of 2x int64 + assert sr_rng_state_aval is not None and sr_rng_state_aval.dtype == jnp.uint32, ( + "sr_rng_state must be a uint32 array when stochastic_rounding is True but" + f" received {sr_rng_state_aval}" + ) + if is_outer and get_num_devices_in_mesh() > 1: + assert ( + sr_rng_state_aval.shape[0] == get_num_devices_in_mesh() + and sr_rng_state_aval.shape[1] == 4 + ), ( + "sr_rng_state must be of shape (num_devices, 4) when stochastic_rounding is" + f" True and is_outer is True but received {sr_rng_state_aval.shape}" + ) + else: + # We cannot assert the shape is exactly (4,) here because if the quantized data is not perfectly sharded across all devices then we will have extra rng state here. For example, this could occur when the weights are not sharded when using data parallelism. However, this is okay because the extra rng state will simply not be used and each device still has a unique rng state. + assert sr_rng_state_aval.size >= 4, ( + "Sharded sr_rng_state must have at least 4 elements per device when" + f" stochastic_rounding is True but received {sr_rng_state_aval.shape}" + ) - if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if QuantizeLayout(q_layout).has_rowwise: rowwise_out_shape = out_shape else: rowwise_out_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) + assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), ( + f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. 
out_dtype must" + f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}" + ) + updated_amax_aval = amax_aval + if use_rht: + assert ( + x_aval.dtype == jnp.bfloat16 + ), "x must be of dtype bfloat16 to be eligible for RHT cast fusion." + + if flatten_axis < 0: + flatten_axis += len(x_aval.shape) + rows = reduce(operator.mul, x_aval.shape[:flatten_axis], 1) + cols = reduce(operator.mul, x_aval.shape[flatten_axis:], 1) + assert rows % 64 == 0 and cols % 128 == 0, ( + "Rows must be multiple of 64 and cols multiple of 128 when use_rht is True to be" + f" eligible for RHT cast fusion. Received rows {rows} and cols {cols} of 2D shape" + f" from original shape of {x_aval.shape} with flatten_axis {flatten_axis}." + ) + + assert ( + rht_matrix_aval is not None + and rht_matrix_aval.dtype == jnp.bfloat16 + and rht_matrix_aval.shape == (16, 16) + ), "rht_matrix must be of shape (16, 16) and dtype bfloat16" + assert ( + post_rht_amax_aval is not None + and post_rht_amax_aval.dtype == jnp.float32 + and post_rht_amax_aval.size == 1 + ), "post_rht_amax must be of dtype float32" + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x_aval.shape, is_padded=not is_outer, flatten_axis=flatten_axis) + ).get_scale_shape_2x( + x_aval.shape, + is_padded=not is_outer, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if QuantizeLayout(q_layout).has_colwise: + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_shape = multidim_transpose(out_shape, transpose_axis=flatten_axis) else: colwise_out_shape = out_shape @@ -128,10 +194,9 @@ def abstract( gi_hidden_size, jax_dtype_to_te_dtype(x_aval.dtype), jax_dtype_to_te_dtype(out_dtype), + jax_dtype_to_te_dtype(scale_dtype), scaling_mode, - QuantizeLayout( - q_layout - ), # For now until we have auto-decoding for 
QuantizeLayout enum + q_layout.value, ) wkspace_shape = wkspace_info[0] wkspace_dtype = te_dtype_to_jax_dtype(wkspace_info[1]) @@ -174,6 +239,9 @@ def lowering( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, *, out_dtype, scaling_mode, @@ -182,12 +250,14 @@ def lowering( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p lowering rules """ del out_dtype, scale_dtype, is_outer - x_aval, scale_aval, amax_aval = ctx.avals_in + x_aval, scale_aval, amax_aval, _, _, _ = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] assert scale_aval.dtype == amax_aval.dtype == jnp.float32 return ffi.ffi_lowering( @@ -198,10 +268,15 @@ def lowering( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, scaling_mode=scaling_mode.value, - q_layout=q_layout, + q_layout=q_layout.value.value, flatten_axis=flatten_axis, is_dbias=is_dbias, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) @staticmethod @@ -209,6 +284,9 @@ def impl( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype, scaling_mode, q_layout, @@ -216,6 +294,8 @@ def impl( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ te_dbias_quantize_p implementation @@ -234,6 +314,9 @@ def impl( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -241,14 +324,18 @@ def impl( scale_dtype=scale_dtype, is_dbias=is_dbias, is_outer=False, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode - ).get_scale_shape_2x(x.shape, is_padded=False, flatten_axis=flatten_axis) + ).get_scale_shape_2x( + x.shape, is_padded=False, flatten_axis=flatten_axis, broadcast_2d_scale_shape_to_1d=True + ) scale_inv = jax.lax.slice( scale_inv, [0] * len(rowwise_scale_inv_shape), rowwise_scale_inv_shape ) - if q_layout in 
(QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: colwise_scale_inv = jax.lax.slice( colwise_scale_inv, [0] * len(colwise_scale_inv_shape), colwise_scale_inv_shape ) @@ -273,6 +360,8 @@ def batcher( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, ): """ to describe batch rules for vmap @@ -280,8 +369,8 @@ def batcher( del is_outer check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale, amax = batched_args - x_bdim, scale_bdim, amax_bdim = batch_dims + x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix = batched_args + x_bdim, scale_bdim, amax_bdim, _, _, _ = batch_dims out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( @@ -289,12 +378,17 @@ def batcher( x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, flatten_axis=flatten_axis, scale_dtype=scale_dtype, is_dbias=is_dbias, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ), out_bdims, ) @@ -308,11 +402,20 @@ def infer_sharding_from_operands( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, arg_infos, result_infos, ): - del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. + del ( + out_dtype, + result_infos, + scale_dtype, + is_outer, + stochastic_rounding, + use_rht, + ) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) @@ -321,8 +424,8 @@ def infer_sharding_from_operands( PartitionSpec(*x_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + if q_layout.has_colwise: + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -342,11 +445,19 @@ def infer_sharding_from_operands( ) scale_inv_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv_spec = scale_inv_spec + if q_layout.has_colwise: + if ( + ScalingMode(scaling_mode).is_block_scaling + and ScalingMode(scaling_mode).is_colwise_transposed + ): + colwise_scale_inv_spec = multidim_transpose( + scale_inv_spec, transpose_axis=flatten_axis + ) + else: + colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" @@ -378,11 +489,13 @@ def partition( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, arg_infos, result_infos, ): - del result_infos, is_outer + del result_infos, is_outer # Unused. 
x_spec = get_padded_spec(arg_infos[0]) amax_spec = get_padded_spec(arg_infos[2]) @@ -391,8 +504,9 @@ def partition( PartitionSpec(*x_spec), desc="BaseDBiasQuantizePrimitive.out_sharding", ) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): + + if q_layout.has_colwise: + if ScalingMode(scaling_mode).is_colwise_transposed: colwise_out_spec = multidim_transpose(x_spec, transpose_axis=flatten_axis) else: colwise_out_spec = x_spec @@ -412,11 +526,19 @@ def partition( ) scale_inv_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + if ScalingMode(scaling_mode).is_block_scaling: scale_inv_spec = x_spec - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - colwise_scale_inv_spec = scale_inv_spec + if q_layout.has_colwise: + if ( + ScalingMode(scaling_mode).is_block_scaling + and ScalingMode(scaling_mode).is_colwise_transposed + ): + colwise_scale_inv_spec = multidim_transpose( + scale_inv_spec, transpose_axis=flatten_axis + ) + else: + colwise_scale_inv_spec = scale_inv_spec scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" @@ -430,7 +552,13 @@ def partition( desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + arg_shardings = list(arg_i.sharding for arg_i in arg_infos) + arg_shardings[3] = NamedSharding( + mesh, + PartitionSpec(tuple(x for x in x_spec if x is not None), None), + desc="BaseDBiasQuantizePrimitive.sr_rng_state", + ) + arg_shardings = tuple(arg_shardings) out_shardings = ( out_sharding, colwise_out_sharding, @@ -440,7 +568,10 @@ def partition( dbias_sharding, ) - def sharded_impl(x, scale, amax): + def sharded_impl(x, scale, amax, sr_rng_state, post_rht_amax, rht_matrix): + if sr_rng_state.size > 4: + # See comment in abstract method for explanation of why we 
cannot assert exact shape + sr_rng_state = sr_rng_state.flatten()[:4] ( local_x, local_colwise_x, @@ -452,6 +583,9 @@ def sharded_impl(x, scale, amax): x, scale, amax, + sr_rng_state, + post_rht_amax, + rht_matrix, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -459,6 +593,8 @@ def sharded_impl(x, scale, amax): scale_dtype=scale_dtype, is_dbias=is_dbias, is_outer=True, + stochastic_rounding=stochastic_rounding, + use_rht=use_rht, ) if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: @@ -491,38 +627,54 @@ def shardy_sharding_rule( scale_dtype, is_dbias, is_outer, + stochastic_rounding, + use_rht, mesh, value_types, result_types, ): - if version.parse(jax.__version__) < version.parse("0.5.0"): - raise ImportError("JAX version 0.5.0 or later is required for shardy sharding.") - del out_dtype, scale_dtype, is_outer, mesh, result_types + del ( + out_dtype, + scale_dtype, + is_outer, + stochastic_rounding, + use_rht, + mesh, + result_types, + ) - prefix = "BaseDBiasQuantizePrimitive_" + prefix = "DBiasQuantize" scale_rules = ScalingMode(scaling_mode).get_shardy_sharding_rules( - len(value_types[0].shape), - unique_var=prefix + "x", + value_types[0].shape, + unique_var=prefix, flatten_axis=flatten_axis, + q_layout=q_layout, + broadcast_2d_scale_shape_to_1d=True, ) - x_axes = scale_rules.input_spec - colwise_scale_inv = scale_rules.colwise_rule - - out = x_axes - colwise_out = (prefix + "out_colwise",) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): - if ScalingMode(scaling_mode).is_tensor_scaling(): - colwise_out = tuple(multidim_transpose(x_axes, transpose_axis=flatten_axis)) - else: - colwise_out = x_axes + input_spec = scale_rules.input_spec + dbias = input_spec[flatten_axis:] if is_dbias else (prefix + "_dbias",) + amax = (BATCHING + prefix + "_amax",) + scale = (BATCHING + prefix + "_scale",) + sr_rng_state = ( + BATCHING + prefix + "_sr_rng_state_partition_axis", + BATCHING + prefix + 
"sr_rng_state_data_axis", + ) - dbias = x_axes[flatten_axis:] if is_dbias else (prefix + "dbias",) - amax = (prefix + "amax",) + post_rht_amax = (BATCHING + prefix + "_post_rht_amax",) + rht_matrix = (BATCHING + prefix + "_rht_matrix_1", BATCHING + prefix + "_rht_matrix_2") return SdyShardingRule( - (x_axes, ("…1",), amax), - (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), + (input_spec, scale, amax, sr_rng_state, post_rht_amax, rht_matrix), + ( + scale_rules.rowwise_out_spec, + scale_rules.colwise_out_spec, + scale_rules.rowwise_scale_spec, + scale_rules.colwise_scale_spec, + amax, + dbias, + ), + **scale_rules.factor_sizes, ) @@ -583,6 +735,8 @@ def _quantize_dbias_impl( is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, # Only works when using current-scaling + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -606,7 +760,12 @@ def _quantize_dbias_impl( # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, # fall back on the native-JAX quantize implementation PrimitiveClass = DBiasQuantizePrimitive if is_dbias else QuantizePrimitive - if quantizer.q_layout == QuantizeLayout.COLWISE or not PrimitiveClass.enabled(): + is_unsupported = quantizer.q_layout.is_colwise_only and not ( + quantizer.scaling_mode == ScalingMode.NVFP4_1D_SCALING + and hasattr(quantizer, "use_rht") + and quantizer.use_rht + ) + if is_unsupported or not PrimitiveClass.enabled(): if is_dbias: return _jax_quantize_dbias( x, @@ -627,24 +786,56 @@ def _quantize_dbias_impl( quantizer=quantizer, dq_dtype=dq_dtype, flatten_axis=flatten_axis, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias - scale = jnp.empty((), jnp.float32) - amax = None + use_rht = False + + scale = jnp.empty((1,), jnp.float32) + 
post_rht_amax = None + rht_matrix = jnp.empty((1, 1), jnp.bfloat16) + amax = x.amax + + if hasattr(quantizer, "use_rht") and quantizer.use_rht: + use_rht = True + rht_matrix = get_rht_matrix() + + new_amax, post_rht_amax = calculate_post_rht_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + produce_regular_amax=amax is None, + flatten_axis=flatten_axis, + ) + if amax is None: + # If amax is already calculated in a previous layer, we skip calculating it in the TE kernel + # So here we only calculate and update amax when it is not provided from a previous layer (amax is None) + amax = new_amax + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - # Globally reduce amax across all devices for current scaling so we have a single global scale. - # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this - # until the tensor is dequantized (e.g. in the GEMM). - amax = x.amax if amax is None: - amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) - scale = compute_scale_from_amax(amax, quantizer.q_dtype) + amax = calculate_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) + scale = compute_scale_from_amax(amax, quantizer.q_dtype, margin=0.0) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # Make sure to reset amax to zeros for DelayedScaling + amax = jnp.zeros((1,), jnp.float32) + elif quantizer.scaling_mode.is_nvfp4_scaling: + if amax is None: + amax = calculate_amax( + x.data, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, + ) - # Make sure amax is init with zero + # Make sure amax is not None if amax is None: amax = jnp.zeros((1,), jnp.float32) @@ -652,13 +843,20 @@ def _quantize_dbias_impl( is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( 
quantizer.scaling_mode.is_tensor_scaling() - and quantizer.is_2x2x() + and quantizer.q_layout.is_rowwise_colwise and is_1x_kernel_supported ) q_layout = quantizer.q_layout + if force_1x_quantization: q_layout = QuantizeLayout.ROWWISE + sr_rng_state = None + if quantizer.scaling_mode.is_nvfp4_scaling: + # Only NVFP4 scaling modes support stochastic rounding + if quantizer.stochastic_rounding_rng_state is not None: + sr_rng_state = quantizer.stochastic_rounding_rng_state + ( rowwise_casted_output, colwise_casted_output, @@ -670,19 +868,28 @@ def _quantize_dbias_impl( x.data, scale, amax, + ( + sr_rng_state + if sr_rng_state is not None + else jnp.empty((get_num_devices_in_mesh(), 1), jnp.uint32) + ), + post_rht_amax if post_rht_amax is not None else jnp.zeros((1,), jnp.float32), + rht_matrix, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=q_layout.value, + q_layout=q_layout, flatten_axis=flatten_axis, scale_dtype=quantizer.get_scale_dtype(), - is_dbias=is_dbias, + is_dbias=is_dbias if not quantizer.scaling_mode.is_nvfp4_scaling else False, is_outer=True, + stochastic_rounding=sr_rng_state is not None, + use_rht=use_rht, ) # For DelayedScaling2x, the scale buffer is shared between rowwise and colwise - if quantizer.scaling_mode.is_tensor_scaling() and quantizer.is_2x2x(): + if quantizer.scaling_mode.is_tensor_scaling() and quantizer.q_layout.is_rowwise_colwise: colwise_scale_inv = rowwise_scale_inv - if q_layout == QuantizeLayout.ROWWISE: + if q_layout.is_rowwise_only: # Quantizer requires 2x quantization, but we are using 1x quantization # for performance reasons, so we need to generate the colwise data in JAX if flatten_axis < 0: @@ -690,19 +897,23 @@ def _quantize_dbias_impl( colwise_casted_output = jnp.transpose( rowwise_casted_output, (*range(flatten_axis, x.ndim), *range(flatten_axis)) ) - quantizer.update(updated_amax) + if quantizer.scaling_mode.is_nvfp4_scaling and is_dbias: + dbias = _jax_dbias(x, 
flatten_axis=flatten_axis) out = ScaledTensorFactory.create( data=rowwise_casted_output, scale_inv=rowwise_scale_inv, colwise_data=colwise_casted_output, colwise_scale_inv=colwise_scale_inv, + amax=updated_amax, + colwise_amax=post_rht_amax, scaling_mode=quantizer.scaling_mode, dq_dtype=dq_dtype, q_layout=quantizer.q_layout, data_layout=quantizer.get_data_layout(), flatten_axis=flatten_axis, + colwise_has_rht_applied=use_rht, ) return out, dbias.astype(dq_dtype) @@ -711,6 +922,8 @@ def quantize( x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -721,6 +934,7 @@ def quantize( flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. is None. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. Returns: A ScaledTensor containing the quantized input tensor. @@ -729,6 +943,8 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) return out @@ -738,6 +954,8 @@ def quantize_dbias( quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, + amax_scope: AmaxScope = AmaxScope.LOCAL, + transpose_batch_sequence: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -748,6 +966,8 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. + amax_scope: Indicate the scope to run amax calculation. This only works when using current-scaling. Default is AmaxScope.LOCAL. 
+ Returns: A tuple containing: @@ -761,6 +981,8 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, + amax_scope=amax_scope, + transpose_batch_sequence=transpose_batch_sequence, ) @@ -804,6 +1026,11 @@ def abstract( # TODO(Phuong): can scale_aval be None? assert scale_aval is None or scale_aval.dtype == jnp.float32 + assert out_dtype in ScalingMode(scaling_mode).get_compatible_q_dtypes(), ( + f"out_dtype {out_dtype} not compatible with scaling_mode {scaling_mode}. out_dtype must" + f" be one of {ScalingMode(scaling_mode).get_compatible_q_dtypes()}" + ) + rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode ).get_grouped_scale_shape_2x( @@ -814,7 +1041,7 @@ def abstract( flatten_axis=flatten_axis, ) - if q_layout in (QuantizeLayout.ROWWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_rowwise: rowwise_out_shape = out_shape else: rowwise_out_shape = (1,) @@ -823,7 +1050,7 @@ def abstract( amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) - if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): + if q_layout.has_colwise: colwise_out_shape = out_shape else: colwise_out_shape = (1,) @@ -888,7 +1115,7 @@ def lowering( scale, group_sizes, scaling_mode=scaling_mode.value, - q_layout=q_layout, + q_layout=q_layout.value.value, flatten_axis=flatten_axis, ) @@ -1002,7 +1229,7 @@ def grouped_quantize( ) grouped_amax = jax.ops.segment_max(row_amax, segment_ids, num_segments=n_groups) for i in range(n_groups): - tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype) + tmp_scale = compute_scale_from_amax(grouped_amax[i], quantizer.q_dtype, margin=0.0) scale = scale.at[i].set(tmp_scale[0]) is_tensor_scaling = quantizer.scaling_mode in ( @@ -1011,7 +1238,7 @@ def grouped_quantize( ) # WAR for tensor_scaling as TE/Common does not support q_layout = COLWISE yet # So we performance ROWWISE_COLWISE and use the 
colwise_tensor_output - apply_colwise_war = is_tensor_scaling and quantizer.q_layout == QuantizeLayout.COLWISE + apply_colwise_war = is_tensor_scaling and quantizer.q_layout.is_colwise_only q_layout = QuantizeLayout.ROWWISE_COLWISE if apply_colwise_war else quantizer.q_layout ( rowwise_casted_output, @@ -1025,7 +1252,7 @@ def grouped_quantize( group_sizes, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, - q_layout=q_layout.value, + q_layout=q_layout, flatten_axis=flatten_axis, group_axis=group_axis, scale_dtype=quantizer.get_scale_dtype(), @@ -1033,7 +1260,7 @@ def grouped_quantize( # For DelayedScaling2x and CurrentScaling2x, the scale buffer # is shared between rowwise and colwise - if is_tensor_scaling and quantizer.is_2x2x() or apply_colwise_war: + if is_tensor_scaling and quantizer.q_layout.is_rowwise_colwise or apply_colwise_war: colwise_scale_inv = rowwise_scale_inv # TODO(Phuong): store the whole updated_amax in the grouped_quantize instead? diff --git a/transformer_engine/jax/cpp_extensions/softmax.py b/transformer_engine/jax/cpp_extensions/softmax.py index 43cb11a08..575a2dd3a 100644 --- a/transformer_engine/jax/cpp_extensions/softmax.py +++ b/transformer_engine/jax/cpp_extensions/softmax.py @@ -6,22 +6,16 @@ from functools import partial, reduce import operator import warnings -from packaging import version import jax import jax.numpy as jnp -from jax import dtypes +from jax import dtypes, ffi from jax.sharding import PartitionSpec, NamedSharding from .base import BasePrimitive, register_primitive from .misc import get_padded_spec, check_valid_batch_dims from ..softmax import SoftmaxType -if version.parse(jax.__version__) >= version.parse("0.5.0"): - from jax import ffi # pylint: disable=ungrouped-imports -else: - from jax.extend import ffi # pylint: disable=ungrouped-imports - __all__ = [ "scaled_softmax_fwd", diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 
453a4202b..2bfa4c89f 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -17,6 +17,9 @@ #include #include #include +#ifndef USE_ROCM +#include +#endif #include #include @@ -36,27 +39,38 @@ #include "transformer_engine/activation.h" #include "transformer_engine/multi_stream.h" -// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace -XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); - namespace transformer_engine { namespace jax { +struct ClampedSwigluConfig { + float limit; + float alpha; +}; + +struct ActivationConfig { + ClampedSwigluConfig clamped_swiglu; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(ActLuInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler); pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - JAXX_Scaling_Mode scaling_mode, bool is_2x); + JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout); // Normalization +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormForwardHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardInitializeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(NormBackwardHandler); pybind11::tuple GetNormForwardWorkspaceSizes(size_t 
batch_size, size_t hidden_size, DType in_dtype, @@ -78,9 +92,9 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, + DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, - QuantizeLayout q_layout); + JAXX_Quantize_Layout quantize_layout); // Softmax XLA_FFI_DECLARE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler); @@ -125,11 +139,17 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // GEMM XLA_FFI_DECLARE_HANDLER_SYMBOL(GemmHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(CollectiveGemmInitHandler); // Grouped GEMM +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedGemmHandler); #ifndef USE_ROCM +// Amax +XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationInitializeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(RHTAmaxCalculationHandler); + // Cudnn helpers XLA_FFI_DECLARE_HANDLER_SYMBOL(CudnnHandleInitHandler); @@ -140,4 +160,17 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(CublasHandleInitHandler); } // namespace jax } // namespace transformer_engine +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING(transformer_engine::jax::ClampedSwigluConfig, + ::xla::ffi::StructMember("limit"), + ::xla::ffi::StructMember("alpha")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::ActivationConfig, + ::xla::ffi::StructMember("clamped_swiglu")); + +// ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Collective_Op); +XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Quantize_Layout); + #endif // TRANSFORMER_ENGINE_JAX_CSRC_FP8_MODULES_H_ diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp 
b/transformer_engine/jax/csrc/extensions/activation.cpp index 17fa9906b..34ce29ae1 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -15,10 +15,15 @@ namespace transformer_engine { namespace jax { Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, + Buffer_Type amax_buf, Result_Type output_buf, Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, - bool is_2x_int) { + Result_Type updated_amax_buf, int64_t act_enum, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params, + bool output_amax_when_no_scaling) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); @@ -27,14 +32,15 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto *output = output_buf->untyped_data(); auto *colwise_output = colwise_output_buf->untyped_data(); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto input_dims = input_buf.dimensions(); auto m = product(input_dims, 0, input_dims.size() - 2); auto n = input_dims.back(); auto act_type = static_cast(act_enum); auto act_len = input_dims[input_dims.size() - 2]; - auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act 
axis auto input_shape = std::vector{m, static_cast(act_len * n)}; @@ -42,7 +48,12 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto output_trans_shape = std::vector{static_cast(n), m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } NVTE_CHECK( scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, @@ -52,10 +63,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); @@ -69,7 +77,7 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal } } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) ? 
output_trans_shape : output_shape; @@ -125,6 +133,10 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::SREGLU: nvte_sreglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_swiglu(input_tensor.data(), output_tensor.data(), swiglu_limit, swiglu_alpha, + stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -138,19 +150,53 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuHandler, ActLuFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv .Ret() // scale_inv colwise - .Ret() // amax + .Ret() // updated_amax .Attr("act_enum") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("quantize_layout") + .Attr("act_params") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); +Error_Type ActLuInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + int64_t act_enum, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, ActivationConfig act_params, + bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(ActLuFFI), stream, input_buf, scale_buf, amax_buf, + output_buf, colwise_output_buf, scale_inv_buf, colwise_scale_inv_buf, + updated_amax_buf, act_enum, scaling_mode, quantize_layout, act_params, + output_amax_when_no_scaling); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(ActLuInitializeHandler, ActLuInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // scale + .Arg() // amax + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // updated_amax + .Attr("act_enum") + .Attr("scaling_mode") + .Attr("quantize_layout") + .Attr("act_params") + 
.Attr("output_amax_when_no_scaling")); + pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType out_dtype, - JAXX_Scaling_Mode scaling_mode, bool is_2x) { + JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto dact_input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; @@ -182,7 +228,7 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid std::vector{1}); } - if (is_2x) { + if (is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); @@ -212,11 +258,17 @@ pybind11::tuple GetDActDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hid Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type colwise_output_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type dbias_buf, - Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, - int64_t act_enum, bool is_2x, bool is_dbias) { + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, + JAXX_Scaling_Mode scaling_mode, int64_t act_enum, + JAXX_Quantize_Layout quantize_layout, bool is_dbias, + ActivationConfig act_params, bool output_amax_when_no_scaling) { + // parameters for clamped swiglu used in GPT OSS + auto swiglu_limit = act_params.clamped_swiglu.limit; + auto swiglu_alpha = act_params.clamped_swiglu.alpha; + auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); auto 
out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); @@ -224,7 +276,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto *input = input_buf.untyped_data(); auto *act_input = act_input_buf.untyped_data(); float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); auto act_type = static_cast(act_enum); auto flatten_axis = output_buf->dimensions().size() - 2; // output has act axis @@ -267,13 +321,14 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, out_dtype, output_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } if (is_fp8_dtype(out_dtype)) { if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); @@ -287,7 +342,7 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, } } - if (is_2x) { + if 
(is_quantize_2x2x(quantize_layout)) { auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) ? output_trans_shape : output_shape; @@ -317,7 +372,8 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, // fused_dgated_dbias is not available, so we use dact_lu + quantize_dbias in Python instead NVTE_CHECK(!(act_len == 2 && is_dbias), "Unsupported DGatedActedDBias Fusion!"); - NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_2x && act_len == 2), + NVTE_CHECK(!(scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && + is_quantize_2x2x(quantize_layout) && act_len == 2), "TE/common does not support delayed scaling for 2x with gated activations."); if (is_dbias) { @@ -383,6 +439,10 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::SREGLU: nvte_dsreglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::CLAMPED_SWIGLU: + nvte_clamped_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), + swiglu_limit, swiglu_alpha, stream); + break; default: NVTE_ERROR("Unsupported ActivationEnum"); break; @@ -398,6 +458,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Arg() // input .Arg() // act input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -407,8 +468,47 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeHandler, DActLuDBiasQuantizeFFI .Ret() // wkspace .Attr("scaling_mode") .Attr("act_enum") - .Attr("is_2x") - .Attr("is_dbias"), + .Attr("quantize_layout") + .Attr("is_dbias") + .Attr("act_params") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); + +Error_Type DActLuDBiasQuantizeInitializeFFI( + cudaStream_t stream, Buffer_Type input_buf, Buffer_Type act_input_buf, Buffer_Type scale_buf, + Buffer_Type amax_buf, Result_Type output_buf, Result_Type 
colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + int64_t act_enum, JAXX_Quantize_Layout quantize_layout, bool is_dbias, + ActivationConfig act_params, bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(DActLuDBiasQuantizeFFI), stream, input_buf, + act_input_buf, scale_buf, amax_buf, output_buf, colwise_output_buf, + scale_inv_buf, colwise_scale_inv_buf, updated_amax_buf, dbias_buf, + workspace_buf, scaling_mode, act_enum, quantize_layout, is_dbias, + act_params, output_amax_when_no_scaling); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(DActLuDBiasQuantizeInitializeHandler, + DActLuDBiasQuantizeInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Arg() // act input + .Arg() // scale + .Arg() // amax + .Ret() // output + .Ret() // colwise output + .Ret() // scale_inv + .Ret() // scale_inv colwise + .Ret() // updated_amax + .Ret() // dbias + .Ret() // wkspace + .Attr("scaling_mode") + .Attr("act_enum") + .Attr("quantize_layout") + .Attr("is_dbias") + .Attr("act_params") + .Attr("output_amax_when_no_scaling")); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/amax.cpp b/transformer_engine/jax/csrc/extensions/amax.cpp new file mode 100644 index 000000000..aa40a8e35 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/amax.cpp @@ -0,0 +1,104 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ +#ifndef USE_ROCM +#include + +#include + +#include "../extensions.h" +#include "transformer_engine/cast.h" +#include "transformer_engine/hadamard_transform.h" +#include "transformer_engine/recipe.h" +#include "transformer_engine/transformer_engine.h" +#include "xla/ffi/api/c_api.h" + +namespace transformer_engine { +namespace jax { + +Error_Type RHTAmaxCalculationFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type amax_buf, + Result_Type post_rht_amax_buf, + int64_t rht_matrix_random_sign_mask_t, bool produce_regular_amax, + int64_t flatten_axis) { + NVTE_CHECK(input_buf.untyped_data() != nullptr, + "Input must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(input_buf.element_type()) == DType::kBFloat16, + "Input must be of type bfloat16 for RHT Amax calculation"); + + NVTE_CHECK(flatten_axis > 0 && flatten_axis < static_cast(input_buf.dimensions().size()), + "Flatten axis is out of bounds"); + TensorWrapper input_tensor(input_buf.untyped_data(), + std::vector{product(input_buf.dimensions(), 0, flatten_axis), + product(input_buf.dimensions(), flatten_axis, + input_buf.dimensions().size())}, + convert_ffi_datatype_to_te_dtype(input_buf.element_type())); + + float *amax_out = nullptr; + if (produce_regular_amax) { + amax_out = reinterpret_cast(amax_buf->untyped_data()); + NVTE_CHECK(amax_out != nullptr, "Amax output must be provided for RHT Amax calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(amax_buf->element_type()) == DType::kFloat32, + "Amax output must be of type float32 for RHT Amax calculation"); + NVTE_CHECK(amax_buf->dimensions().size() == 1 && amax_buf->dimensions()[0] == 1, + "Amax output must be a single float for RHT Amax calculation"); + } + + float *post_rht_amax_out = reinterpret_cast(post_rht_amax_buf->untyped_data()); + NVTE_CHECK(post_rht_amax_out != nullptr, + "Post-RHT Amax output must be provided for RHT Amax 
calculation"); + NVTE_CHECK(convert_ffi_datatype_to_te_dtype(post_rht_amax_buf->element_type()) == DType::kFloat32, + "Post-RHT Amax output must be of type float32 for RHT Amax calculation"); + NVTE_CHECK(post_rht_amax_buf->dimensions().size() == 1 && post_rht_amax_buf->dimensions()[0] == 1, + "Post-RHT Amax output must be a single float for RHT Amax calculation"); + + TensorWrapper out_tensor{}; + out_tensor.set_amax(amax_out, DType::kFloat32, std::vector{1}); + out_tensor.set_columnwise_amax(post_rht_amax_out, DType::kFloat32, std::vector{1}); + + // Zero'ing of amaxes is handled by TE common inside nvte_hadamard_transform_amax + nvte_hadamard_transform_amax(input_tensor.data(), out_tensor.data(), + 0, // Regular amax for rowwise does not apply RHT so mask is 0 + rht_matrix_random_sign_mask_t, stream); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RHTAmaxCalculationHandler, RHTAmaxCalculationFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // amax + .Ret() // post_rht_amax + .Attr("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t + .Attr("produce_regular_amax") // produce_regular_amax + .Attr("flatten_axis"), // flatten_axis + FFI_CudaGraph_Traits); + +Error_Type RHTAmaxCalculationInitializeFFI(cudaStream_t stream, Buffer_Type input_buf, + Result_Type amax_buf, Result_Type post_rht_amax_buf, + int64_t rht_matrix_random_sign_mask_t, + bool produce_regular_amax, int64_t flatten_axis) { + return wrapInStreamCapture(std::function(RHTAmaxCalculationFFI), stream, input_buf, amax_buf, + post_rht_amax_buf, rht_matrix_random_sign_mask_t, produce_regular_amax, + flatten_axis); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + RHTAmaxCalculationInitializeHandler, RHTAmaxCalculationInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // input + .Ret() // amax + .Ret() // post_rht_amax + .Attr("rht_matrix_random_sign_mask_t") // rht_matrix_random_sign_mask_t + .Attr("produce_regular_amax") // produce_regular_amax + 
.Attr("flatten_axis")); // flatten_axis + +} // namespace jax +} // namespace transformer_engine +#endif // #ifndef USE_ROCM diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 342953746..1281eb272 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -20,10 +20,12 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend(bool is_training, DType q_dtype, DTy size_t q_max_seqlen, size_t kv_max_seqlen, size_t qk_head_dim, size_t v_head_dim, int64_t window_size_left, int64_t window_size_right) { + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; auto backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, q_attn_heads, kv_attn_heads, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false, false); return backend; } @@ -135,17 +137,8 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t v_head_dim, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { - // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; - auto kv_tensor = 
TensorWrapper(nullptr, kv_shape, dtype); - - // For separate q, k, v auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; @@ -160,12 +153,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? 
input_batch * max_segments_per_seq : input_batch; @@ -183,34 +178,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); auto ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal to kv_max_seqlen"); - nvte_fused_attn_fwd_qkvpacked( - qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - ragged_offset_tensor.data(), ragged_offset_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), dummy_rng_state_tensor.data(), q_max_seqlen, - kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, - mask_type, window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), ragged_offset_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, 
bias_type, mask_type, window_size_left, - window_size_right, query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported QKVLayout."); - } + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), ragged_offset_tensor.data(), + ragged_offset_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_output_tensors); @@ -279,10 +254,16 @@ static void FusedAttnForwardImpl( /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); + + auto dummy_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; + auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false, false); nvte_populate_rng_state_async(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); /* Auxiliary tensors (to be propagated to the backward pass later) */ @@ -294,45 +275,57 @@ static void FusedAttnForwardImpl( /* Call the underlying NVTE API */ auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector{1}, DType::kInt32); + + // Prepare 
Q, K, V pointers and shapes based on layout + // Python passes dummy tensors for unused slots, so we extract from the actual packed data + void *q_ptr = q; + void *k_ptr = k; + void *v_ptr = v; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(q, qkv_shape, dtype); - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, workspace_tensor.data(), stream); + // QKV packed in q: [batch*seqlen, 3, heads, dim] + // Python passes: q=packed_qkv, k=dummy, v=dummy + // Extract K and V pointers from the packed q data + NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen"); + NVTE_CHECK(qk_head_dim == v_head_dim, + "For QKV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim); + q_ptr = q; + k_ptr = static_cast(static_cast(q) + stride); + v_ptr = static_cast(static_cast(q) + 2 * stride); + // For packed QKV, all have same shape since they're views into the same packed tensor + k_shape = q_shape; + v_shape = q_shape; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto 
kv_tensor = TensorWrapper(k, kv_shape, dtype); - nvte_fused_attn_fwd_kvpacked( - q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), - &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), - dummy_page_table_tensor.data(), rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - nvte_fused_attn_fwd( - q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), - dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), rng_state_tensor.data(), - q_max_seqlen, kv_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); + // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim] + // Python passes: q=query, k=packed_kv, v=dummy + // Extract V pointer from the packed k data + NVTE_CHECK(qk_head_dim == v_head_dim, + "For KV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * num_gqa_groups * 
qk_head_dim); + q_ptr = q; + k_ptr = k; + v_ptr = static_cast(static_cast(k) + stride); + // V has same shape as K since they're packed together + v_shape = k_shape; } + // else NVTE_HD_HD_HD: pointers and shapes already correct + + auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype); + auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype); + auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype); + + nvte_fused_attn_fwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + dummy_softmax_offset_tensor.data(), s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), + k_seq_offsets_tensor.data(), dummy_page_table_tensor.data(), dummy_page_table_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, false, false, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, workspace_tensor.data(), stream); nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -415,20 +408,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training, bool deterministic, size_t max_segments_per_seq, int64_t window_size_left, int64_t window_size_right) { - // For qkv_packed - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - - // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, v_head_dim}; - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - - // For 
separate q, k, v auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); @@ -451,7 +433,6 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper query_workspace_tensor; - auto layout_group = nvte_get_qkv_layout_group(qkv_layout); auto is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD; // It is a WAR to pre-create all possible cuDNN graph at the JIT compile time size_t max_num_segments = is_ragged ? input_batch * max_segments_per_seq : input_batch; @@ -461,6 +442,9 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( // For cuDNN < 9.3.0, it requires to run all possible seqlens to address act_seqlen = 0 min_num_segments = input_batch * max_segments_per_seq; } + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; for (auto num_segments = min_num_segments; num_segments <= max_num_segments; ++num_segments) { // the last one is the largest which will be the returned workspace size auto q_cu_seqlens_tensor = @@ -469,41 +453,18 @@ pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); auto dummy_ragged_offset_tensor = TensorWrapper(nullptr, std::vector{num_segments + 1}, DType::kInt32); - if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, query_workspace_tensor.data(), 
nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_ragged_offset_tensor.data(), dummy_ragged_offset_tensor.data(), q_max_seqlen, - kv_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, query_workspace_tensor.data(), - nullptr); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), - dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - window_size_left, window_size_right, deterministic, - query_workspace_tensor.data(), nullptr); - } else { - NVTE_ERROR("Unsupported qkv_layout."); - } + + nvte_fused_attn_bwd( + q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), + dbias_tensor.data(), dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), dummy_ragged_offset_tensor.data(), + dummy_ragged_offset_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, + dropout_probability, qkv_layout, bias_type, 
mask_type, softmax_type, window_size_left, + window_size_right, deterministic, false, query_workspace_tensor.data(), nullptr); } nvte_tensor_pack_destroy(&aux_input_tensors); @@ -532,86 +493,99 @@ static void FusedAttnBackwardImpl( /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); + auto dummy_d_softmax_offset_tensor = + TensorWrapper(nullptr, std::vector{1}, DType::kFloat32); + NVTE_Softmax_Type softmax_type = NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX; /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); auto backend = nvte_get_fused_attn_backend( is_training, static_cast(dtype), static_cast(dtype), qkv_layout, - bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, - kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right); + bias_type, mask_type, softmax_type, dropout_probability, attn_heads, num_gqa_groups, + q_max_seqlen, kv_max_seqlen, qk_head_dim, v_head_dim, window_size_left, window_size_right, + false, false); PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, input_batch, bias_batch, attn_heads, bias_heads, q_max_seqlen, kv_max_seqlen, dtype, backend, softmax_aux, rng_state, bias); /* Call the underly NVTE API */ + // Prepare Q, K, V pointers and shapes based on layout + void *q_ptr = q; + void *k_ptr = k; + void *v_ptr = v; + void *dq_ptr = dq; + void *dk_ptr = dk; + void *dv_ptr = dv; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; + auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { - auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, qk_head_dim}; - auto qkv_tensor = 
TensorWrapper(q, qkv_shape, dtype); - auto dqkv_tensor = TensorWrapper(dq, qkv_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(qkv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); + // QKV packed in q: [batch*seqlen, 3, heads, dim] + NVTE_CHECK(q_max_seqlen == kv_max_seqlen, "q_max_seqlen must equal kv_max_seqlen"); + NVTE_CHECK(qk_head_dim == v_head_dim, + "For QKV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * attn_heads * qk_head_dim); + q_ptr = q; + k_ptr = static_cast(static_cast(q) + stride); + v_ptr = static_cast(static_cast(q) + 2 * stride); + dq_ptr = dq; + dk_ptr = static_cast(static_cast(dq) + stride); + dv_ptr = static_cast(static_cast(dq) + 2 * stride); + k_shape = q_shape; + v_shape = q_shape; } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto kv_shape = - std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, qk_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(k, kv_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dk, kv_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(kv_shape) * typeToSize(dtype), - stream); - } - nvte_fused_attn_bwd_kvpacked( 
- q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, window_size_right, - deterministic, workspace_tensor.data(), stream); - } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_HD_HD) { - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, qk_head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, qk_head_dim}; - auto v_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, v_head_dim}; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); - if (is_ragged) { - (void)cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * typeToSize(dtype), stream); - (void)cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * typeToSize(dtype), stream); + // Q separate, KV packed in k: [batch*seqlen, 2, num_gqa_groups, dim] + NVTE_CHECK(qk_head_dim == v_head_dim, + "For KV packed layout, qk_head_dim must equal v_head_dim"); + size_t stride = (typeToSize(dtype) * num_gqa_groups * qk_head_dim); + q_ptr = q; + k_ptr = k; + v_ptr = static_cast(static_cast(k) + stride); + dq_ptr = dq; + dk_ptr = dk; + dv_ptr = static_cast(static_cast(dk) + stride); + // V has same shape as K since they're packed together + v_shape = k_shape; + 
} + + auto q_tensor = TensorWrapper(q_ptr, q_shape, dtype); + auto k_tensor = TensorWrapper(k_ptr, k_shape, dtype); + auto v_tensor = TensorWrapper(v_ptr, v_shape, dtype); + auto dq_tensor = TensorWrapper(dq_ptr, q_shape, dtype); + auto dk_tensor = TensorWrapper(dk_ptr, k_shape, dtype); + auto dv_tensor = TensorWrapper(dv_ptr, v_shape, dtype); + + if (is_ragged) { + size_t dtype_size = typeToSize(dtype); + if (layout_group == NVTE_QKV_Layout_Group::NVTE_3HD) { + // For packed QKV, dq contains all gradients (dq, dk, dv) - clear all at once + cudaMemsetAsync(dq, 0, 3 * transformer_engine::jax::product(q_shape) * dtype_size, stream); + } else if (layout_group == NVTE_QKV_Layout_Group::NVTE_HD_2HD) { + // Clear dq + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream); + // For packed KV, dk contains both dk and dv - clear all at once + cudaMemsetAsync(dk, 0, 2 * transformer_engine::jax::product(k_shape) * dtype_size, stream); + } else { + // All separate - clear each individually + cudaMemsetAsync(dq, 0, transformer_engine::jax::product(q_shape) * dtype_size, stream); + cudaMemsetAsync(dk, 0, transformer_engine::jax::product(k_shape) * dtype_size, stream); + cudaMemsetAsync(dv, 0, transformer_engine::jax::product(v_shape) * dtype_size, stream); } - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_seq_offsets_tensor.data(), - k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, window_size_left, - window_size_right, deterministic, workspace_tensor.data(), stream); - } else { - NVTE_ERROR("Unsupported qkv_layout."); } + nvte_fused_attn_bwd( + q_tensor.data(), 
k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), dbias_tensor.data(), + dummy_d_softmax_offset_tensor.data(), q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + q_seq_offsets_tensor.data(), k_seq_offsets_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, softmax_type, + window_size_left, window_size_right, deterministic, false, workspace_tensor.data(), stream); + nvte_tensor_pack_destroy(&aux_input_tensors); } diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp new file mode 100644 index 000000000..4d44bb4a8 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.cpp @@ -0,0 +1,263 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef USE_ROCM +#include "cgemm_helper.h" + +#include "common/util/system.h" +#include "nccl.h" + +namespace transformer_engine { +namespace jax { + +ncclUniqueId CommunicatorHandler::coordinate_nccl_unique_id(const std::string &id_type) { + ncclUniqueId unique_id; + + int tp_domain_id = get_tp_domain_id(); + bool is_tp_leader = (get_local_device_id_within_tp_domain() == 0); + + pid_t pgid = getpgid(0); + + std::string base_path = getenv("NVTE_JAX_NCCL_FILE_PATH", "/tmp"); + std::string id_file = base_path + "/nccl_" + id_type + "_unique_id_pgid_" + std::to_string(pgid) + + "_" + std::to_string(num_total_devices) + "_" + std::to_string(tp_size) + + "_domain_" + std::to_string(tp_domain_id) + ".bin"; + + if (is_tp_leader) { + NVTE_CHECK_NCCL(ncclGetUniqueId(&unique_id)); + + // Write the ID to a temporary file + std::ofstream file(id_file, std::ios::binary); + NVTE_CHECK(file.is_open(), "Failed to create NCCL unique ID file: ", id_file); + file.write(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + file.close(); + } else { + // Wait for the ID file to be created and read it + int attempts = 0; + const int max_attempts = 100; + while (attempts < max_attempts) { + std::ifstream file(id_file, std::ios::binary); + if (file.is_open()) { + file.read(reinterpret_cast(&unique_id), sizeof(ncclUniqueId)); + if (file.gcount() == sizeof(ncclUniqueId)) { + file.close(); + break; + } + file.close(); + } + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + attempts++; + } + NVTE_CHECK(attempts < max_attempts, + "Timeout waiting for " + id_type + " NCCL unique ID file from leader: ", id_file); + } + + if (is_tp_leader) { + _nccl_id_file_name.push_back(id_file); + } + + return unique_id; +} + +void CommunicatorHandler::init(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size) { + // Validate inputs + NVTE_CHECK(num_devices_per_process == 1, + 
"num_devices_per_process must be == 1, got num_devices_per_process=", + num_devices_per_process); + NVTE_CHECK(num_total_devices >= 1, + "num_total_devices must be >= 1, got num_total_devices=", num_total_devices); + NVTE_CHECK( + num_total_devices % num_devices_per_process == 0, + "num_total_devices must be divisible by num_devices_per_process, got num_total_devices=", + num_total_devices, ", num_devices_per_process=", num_devices_per_process); + + // Validate TP size + NVTE_CHECK(tp_size > 0, "tp_size must be > 0, got tp_size=", tp_size); + NVTE_CHECK(num_total_devices % tp_size == 0, + "num_total_devices must be divisible by tp_size, got num_total_devices=", + num_total_devices, ", tp_size=", tp_size); + + auto &handler = get(false); + handler.num_total_devices = num_total_devices; + handler.num_devices_per_process = num_devices_per_process; + handler.process_id = process_id; + handler.num_processes = num_total_devices / num_devices_per_process; + handler.tp_size = tp_size; + handler.tp_num_domains = num_total_devices / tp_size; + + // Initialize vectors with the correct size + handler.local_device_ids_within_process.resize(num_devices_per_process); + handler.local_device_ids_within_tp_domain.resize(num_devices_per_process); + handler.tp_domain_ids.resize(num_devices_per_process); + handler.global_device_ids.resize(num_devices_per_process); + handler.tp_comms.resize(num_devices_per_process); + + NVTE_CHECK(0 <= process_id && process_id < handler.num_processes, + "Invalid process_id=", process_id, ", which is out of range [0, ", + handler.num_processes, ")"); + + // Initialize local devices and calculate their global device IDs and TP topology + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + // Use the device that JAX has already assigned to this process + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + handler.local_device_ids_within_process[local_idx] = current_device; + handler.global_device_ids[local_idx] = 
process_id * num_devices_per_process + local_idx; + + // Calculate TP-related values for this device + int global_device_id = handler.global_device_ids[local_idx]; + if (num_devices_per_process == tp_size) { + // Scenario 1: Multi-device per process - TP domain = single process + handler.local_device_ids_within_tp_domain[local_idx] = local_idx; + handler.tp_domain_ids[local_idx] = process_id; + } else { + // Scenario 2: Single device per process - TP domain spans multiple processes + handler.local_device_ids_within_tp_domain[local_idx] = global_device_id % tp_size; + handler.tp_domain_ids[local_idx] = global_device_id / tp_size; + } + } + + ncclUniqueId tp_id = handler.coordinate_nccl_unique_id("tp"); + + NVTE_CHECK_NCCL(ncclGroupStart()); + for (int local_idx = 0; local_idx < num_devices_per_process; local_idx++) { + NVTE_CHECK_CUDA(cudaSetDevice(handler.local_device_ids_within_process[local_idx])); + int tp_local_rank = handler.local_device_ids_within_tp_domain[local_idx]; + NVTE_CHECK_NCCL( + ncclCommInitRank(&handler.tp_comms[local_idx], handler.tp_size, tp_id, tp_local_rank)); + } + NVTE_CHECK_NCCL(ncclGroupEnd()); + + // Allocate device memory for barrier operations + NVTE_CHECK_CUDA(cudaMalloc(&handler._device_barrier, sizeof(int))); + + handler._initialize = true; + + // Bootstrap UB via creating a dummy CommOverlapP2PBase object + std::vector buffer_shape{1, 1}; + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32, + JAXX_Collective_Op::ALL_GATHER); +} + +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag) { + auto &config = CgemmConfig::get(false); + config.init(num_max_streams, gemm_priority, comm_priority, num_comm_sm, use_ce, aggregate_ag); + auto &handler = CommunicatorHandler::get(false); + handler.init(num_total_devices, 
num_devices_per_process, process_id, tp_size); +} + +int GetCgemmNumMaxStreams() { + auto &config = CgemmConfig::get(); + return config.num_max_streams; +} + +CommOverlapCore *CollectiveGemmPlanRegistry::get_executor(std::vector buffer_shape, + DType dtype, + JAXX_Collective_Op collective_op) { + auto &comm_handler = CommunicatorHandler::get(); + auto &cgemm_config = CgemmConfig::get(); + + int device_idx = comm_handler.get_local_device_idx_for_current_device(); + int64_t plan_id = 0; + hash_combine(plan_id, buffer_shape[0], buffer_shape[1], static_cast(dtype), + static_cast(collective_op), comm_handler.tp_size, cgemm_config.num_max_streams, + cgemm_config.gemm_priority, cgemm_config.comm_priority, cgemm_config.num_comm_sm, + cgemm_config.use_ce, cgemm_config.aggregate_ag, device_idx); + + auto it = plan_map.find(plan_id); + if (it != plan_map.end()) { + return it->second.get(); + } + + if (comm_handler.num_devices_per_process == comm_handler.tp_size) { + // Multi-device per process + } else if (comm_handler.num_devices_per_process == 1) { + // Single device per process + NVTE_CHECK(comm_handler.num_total_devices % comm_handler.tp_size == 0, + "For single device per process, num_total_devices must be divisible by tp_size, " + "got num_total_devices=", + comm_handler.num_total_devices, ", tp_size=", comm_handler.tp_size); + } else { + NVTE_ERROR("Unsupported TP configuration: num_devices_per_process=", + comm_handler.num_devices_per_process, ", tp_size=", comm_handler.tp_size, + ". 
Supported scenarios: " + "(1) num_devices_per_process == tp_size (multi-device per process), " + "(2) num_devices_per_process == 1 (single device per process)"); + } + + std::unique_ptr executor; + executor = std::make_unique( + buffer_shape, dtype, comm_handler.get_global_rank(), comm_handler.num_total_devices, + comm_handler.get_local_device_id_within_tp_domain(), comm_handler.tp_size, + comm_handler.get_tp_domain_id(), comm_handler.get_tp_num_domains(), comm_handler.tp_size, + comm_handler.allgather_func, comm_handler.barrier_func, get_nvte_collective_op(collective_op), + cgemm_config.num_max_streams, 1 /*comm_cga_size*/, cgemm_config.gemm_priority, + cgemm_config.comm_priority, cgemm_config.num_comm_sm, true /*set_sm_margin*/, + cgemm_config.use_ce, false /*atomic_gemm*/, cgemm_config.aggregate_ag); + + CommOverlapCore *executor_ptr = executor.get(); + plan_map[plan_id] = std::move(executor); + return executor_ptr; +} + +void CommunicatorHandler::nccl_device_barrier_impl(ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using barrier"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + NVTE_CHECK_NCCL( + ncclAllReduce(_device_barrier, _device_barrier, 1, ncclInt, ncclSum, tp_comm, nullptr)); + cudaDeviceSynchronize(); +} + +void CommunicatorHandler::nccl_allgather_impl(void *output_buf, size_t output_bytes, + void *input_buf, size_t input_bytes, ExtComm) { + NVTE_CHECK(_initialize, "CommunicatorHandler must be initialized before using allgather"); + + int device_idx = get_local_device_idx_for_current_device(); + ncclComm_t tp_comm = tp_comms[device_idx]; + + size_t expected_output_bytes = input_bytes * tp_size; + NVTE_CHECK(output_bytes == expected_output_bytes, "TP allgather buffer size mismatch: expected ", + expected_output_bytes, ", got ", output_bytes); + + NVTE_CHECK_NCCL(ncclAllGather(input_buf, output_buf, input_bytes, ncclChar, tp_comm, nullptr)); + 
cudaDeviceSynchronize(); +} + +CommunicatorHandler::CommunicatorHandler() : _device_barrier(nullptr) { + allgather_func = [this](void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm comm) { + this->nccl_allgather_impl(output_buf, output_bytes, input_buf, input_bytes, comm); + }; + barrier_func = [this](ExtComm comm) { this->nccl_device_barrier_impl(comm); }; +} + +CommunicatorHandler::~CommunicatorHandler() { + if (_initialize && !tp_comms.empty()) { + for (auto &comm : tp_comms) { + if (comm != nullptr) { + ncclCommDestroy(comm); + } + } + } + if (_device_barrier) cudaFree(_device_barrier); + + for (const auto &file_path : _nccl_id_file_name) { + std::remove(file_path.c_str()); + } +} + +} // namespace jax +} // namespace transformer_engine +#endif //#ifndef USE_ROCM diff --git a/transformer_engine/jax/csrc/extensions/cgemm_helper.h b/transformer_engine/jax/csrc/extensions/cgemm_helper.h new file mode 100644 index 000000000..03d86c168 --- /dev/null +++ b/transformer_engine/jax/csrc/extensions/cgemm_helper.h @@ -0,0 +1,193 @@ +/************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#ifndef TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ +#define TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ + +#ifndef USE_ROCM +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "../extensions.h" +#include "common/comm_gemm_overlap/userbuffers/userbuffers.h" +#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" +#include "transformer_engine/comm_gemm_overlap.h" + +namespace transformer_engine { +namespace jax { + +// Configuration singleton for CGEMM parameters +class CgemmConfig { + public: + int num_max_streams; + int gemm_priority; + int comm_priority; + int num_comm_sm; + bool use_ce; + bool aggregate_ag; + + static void init(int _num_max_streams, int _gemm_priority, int _comm_priority, int _num_comm_sm, + bool _use_ce, bool _aggregate_ag) { + auto &config = get(false); + config._initialized = true; + config.num_max_streams = _num_max_streams; + config.gemm_priority = _gemm_priority; + config.comm_priority = _comm_priority; + config.num_comm_sm = _num_comm_sm; + config.use_ce = _use_ce; + config.aggregate_ag = _aggregate_ag; + } + + static CgemmConfig &get(bool is_initialized = true) { + static thread_local CgemmConfig instance; + NVTE_CHECK( + instance._initialized == is_initialized, + "CgemmConfig must be initialized before using it, got is_initialized=", is_initialized); + return instance; + } + + CgemmConfig(const CgemmConfig &) = delete; + CgemmConfig &operator=(const CgemmConfig &) = delete; + + private: + CgemmConfig() = default; + ~CgemmConfig() = default; + bool _initialized = false; +}; + +// Forward declaration +class CollectiveGemmPlanRegistry; + +// NCCL communicator handler for collective GEMM operations +// Support both single process single device AND single process multi device +// Two scenarios: +// 1. Single process multiple devices: TP domain = process (num_devices_per_process == tp_size) +// 2. 
Single process single device: TP domain spans processes (num_devices_per_process == 1) +class CommunicatorHandler { + public: + int num_total_devices = -1; + int num_devices_per_process = -1; + int process_id = -1; + int num_processes = -1; + + int tp_size = -1; + int tp_num_domains = -1; + std::vector local_device_ids_within_tp_domain; + std::vector tp_domain_ids; + std::vector tp_comms; + + std::vector local_device_ids_within_process; + std::vector global_device_ids; + + int get_global_rank() const { + int device_idx = get_local_device_idx_for_current_device(); + return global_device_ids[device_idx]; + } + + void nccl_device_barrier_impl(ExtComm); + void nccl_allgather_impl(void *output_buf, size_t output_bytes, void *input_buf, + size_t input_bytes, ExtComm); + + ncclComm_t get_comm_for_current_device() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_comms[device_idx]; + } + + int get_local_device_idx_for_current_device() const { + int current_device; + NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); + for (int i = 0; i < num_devices_per_process; i++) { + if (local_device_ids_within_process[i] == current_device) { + return i; + } + } + NVTE_ERROR("Current CUDA device ", current_device, + " not found in local_device_ids_within_process"); + } + + int get_local_device_id_within_tp_domain() const { + int device_idx = get_local_device_idx_for_current_device(); + return local_device_ids_within_tp_domain[device_idx]; + } + + int get_tp_domain_id() const { + int device_idx = get_local_device_idx_for_current_device(); + return tp_domain_ids[device_idx]; + } + + int get_tp_num_domains() const { return tp_num_domains; } + + static void init(int num_total_devices, int num_devices_per_process, int process_id, int tp_size); + + private: + ncclUniqueId coordinate_nccl_unique_id(const std::string &id_type); + + public: + static CommunicatorHandler &get(bool is_initialized = true) { + static CommunicatorHandler instance; + 
NVTE_CHECK(instance._initialize == is_initialized, + "CommunicatorHandler._initialize=", instance._initialize, + ", is_initialized=", is_initialized); + return instance; + } + + ExtAllgatherOp allgather_func; + ExtBarrierOp barrier_func; + + CommunicatorHandler(const CommunicatorHandler &) = delete; + CommunicatorHandler &operator=(const CommunicatorHandler &) = delete; + + private: + CommunicatorHandler(); + ~CommunicatorHandler(); + + bool _initialize = false; + int *_device_barrier = nullptr; + std::vector _nccl_id_file_name; +}; + +// Plan registry for caching collective GEMM executors +class CollectiveGemmPlanRegistry { + public: + static CollectiveGemmPlanRegistry &getInstance() { + static thread_local CollectiveGemmPlanRegistry instance; + return instance; + } + + CommOverlapCore *get_executor(std::vector buffer_shape, DType dtype, + JAXX_Collective_Op collective_op); + + private: + CollectiveGemmPlanRegistry() {} + CollectiveGemmPlanRegistry(const CollectiveGemmPlanRegistry &) = delete; + CollectiveGemmPlanRegistry &operator=(const CollectiveGemmPlanRegistry &) = delete; + + std::unordered_map> plan_map; +}; + +// Function declarations +void InitializeCgemmCommunicator(int num_total_devices, int num_devices_per_process, int process_id, + int tp_size, int num_max_streams, int gemm_priority, + int comm_priority, int num_comm_sm, bool use_ce, + bool aggregate_ag); + +int GetCgemmNumMaxStreams(); + +} // namespace jax +} // namespace transformer_engine + +#endif // #ifndef USE_ROCM +#endif // TRANSFORMER_ENGINE_JAX_CGEMM_HELPER_H_ diff --git a/transformer_engine/jax/csrc/extensions/ffi.cpp b/transformer_engine/jax/csrc/extensions/ffi.cpp index 26764e8af..9cc33e538 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.cpp +++ b/transformer_engine/jax/csrc/extensions/ffi.cpp @@ -49,6 +49,9 @@ DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType &type) { case xla::ffi::DataType::F8E8M0FNU: return DType::kFloat8E8M0; break; + case 
xla::ffi::DataType::F4E2M1FN: + return DType::kFloat4E2M1; + break; default: auto type_num = static_cast(type); NVTE_ERROR("TE does not support conversion of XLA_FFI_DataType %d", diff --git a/transformer_engine/jax/csrc/extensions/ffi.h b/transformer_engine/jax/csrc/extensions/ffi.h index 852a67c6c..0fc2e8389 100644 --- a/transformer_engine/jax/csrc/extensions/ffi.h +++ b/transformer_engine/jax/csrc/extensions/ffi.h @@ -24,6 +24,7 @@ using FFI_Stream_Type = xla::ffi::PlatformStream; using Dictionary = xla::ffi::Dictionary; constexpr auto FFI_Prepare = xla::ffi::ExecutionStage::kPrepare; +constexpr auto FFI_Initialize = xla::ffi::ExecutionStage::kInitialize; constexpr auto FFI_CudaGraph_Traits = {xla::ffi::Traits::kCmdBufferCompatible}; DType convert_ffi_datatype_to_te_dtype(const xla::ffi::DataType& type); @@ -101,10 +102,26 @@ inline static size_t te_dtype_bytes(const DType& type) { return 1; case DType::kFloat8E8M0: return 1; + case DType::kFloat4E2M1: + return 1; default: NVTE_ERROR("Unsupported DType: ", static_cast(type)); } } +template +Error_Type wrapInStreamCapture(std::function func, + cudaStream_t stream, Args... 
args) { + cudaGraph_t graph{}; + NVTE_CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed)); + + Error_Type error = func(stream, std::forward(args)...); + + NVTE_CHECK_CUDA(cudaStreamEndCapture(stream, &graph)); + NVTE_CHECK_CUDA(cudaGraphDestroy(graph)); + + return error; +} + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 7015c2f5e..d35b2d072 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -9,13 +9,23 @@ #include #include +#include +#include #include #include #include "../extensions.h" +#ifndef USE_ROCM +#include "cgemm_helper.h" +#endif //#ifndef USE_ROCM +#include "common.h" #include "common/util/cuda_runtime.h" #include "common/util/string.h" #include "common/util/system.h" +#include "cuda_runtime.h" +#ifndef USE_ROCM +#include "nccl.h" +#endif //#ifndef USE_ROCM #include "transformer_engine/swizzle.h" #include "xla/ffi/api/c_api.h" @@ -38,8 +48,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { } std::tuple> xla_buffer_to_nvte_gemm_operand( - cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, - size_t axis_boundary, bool rowwise) { + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, uint8_t *swizzle_scale_ptr, + JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { // Set tensor data with collapsed 2D shape auto buffer_dims = buffer.dimensions(); std::vector input_shape = {product(buffer_dims, 0, axis_boundary), @@ -55,69 +65,169 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( // Set scaling factor for quantized tensors if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { - NVTE_CHECK(typeToSize(input_dtype) == 1, "Quantized GEMM requires 8-bit operands."); + NVTE_CHECK(is_nvfp4_scaling(scaling_mode) || typeToSize(input_dtype) == 1, + "Quantized GEMM requires 
4-bit or 8-bit operands."); NVTE_CHECK(scale_inv.element_count() > 0, "Missing inverse scaling factor for quantized GEMM."); std::vector scale_shape = {1}; - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + auto is_nvfp4 = is_nvfp4_scaling(scaling_mode); + auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING || is_nvfp4) { // Block scaling also needs to be collapsed to match 2D data scale_shape = {product(scale_inv.dimensions(), 0, axis_boundary), product(scale_inv.dimensions(), axis_boundary, scale_inv.dimensions().size())}; + NVTE_CHECK(typeToSize(scale_dtype) == 1, + "Inverse scale factors need to have an 8-bit data type."); } - - auto scale_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - if (rowwise) { + if (!is_nvfp4) { + if (rowwise) { + input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } else { + input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + } + } else { // Swizzle for NVFP4 +#ifdef USE_ROCM + NVTE_ERROR("ROCm TE does not support NVFP4 yet."); + } +#else + NVTE_CHECK(rowwise, "NVFP4 GEMM expects rowwise for both LHS and RHS"); input.set_rowwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); + // Create tensor to hold swizzled scale factor + TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); + output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); + output.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); + // Launch swizzle kernel + nvte_swizzle_scaling_factors(input.data(), output.data(), stream); + // Set swizzled scales into the input tensor + input.set_rowwise_scale_inv(swizzle_scale_ptr, scale_dtype, scale_shape); } +#endif // #ifdef USE_ROCM } return std::make_tuple(std::move(input), input_shape); } +#ifndef 
USE_ROCM +Error_Type CollectiveGemmInitFFI(Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, + Buffer_Type rhs_scale_inv, Buffer_Type bias, + Buffer_Type gelu_input, Buffer_Type alpha, Buffer_Type beta, + Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, + JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + int64_t rhs_axis_boundary, bool lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + nvte_cublas_handle_init(); + + // Init UB buffer + if (collective_op != JAXX_Collective_Op::NONE) { + auto &comm_handler = CommunicatorHandler::get(); + std::vector lhs_shape = { + product(lhs.dimensions(), 0, lhs_axis_boundary), + product(lhs.dimensions(), lhs_axis_boundary, lhs.dimensions().size())}; + std::vector rhs_shape = { + product(rhs.dimensions(), 0, rhs_axis_boundary), + product(rhs.dimensions(), rhs_axis_boundary, rhs.dimensions().size())}; + + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], + (rhs_transposed) ? 
rhs_shape[0] : rhs_shape[1]}; + + std::vector buffer_shape{0, 0}; + DType buffer_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + } + auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, buffer_dtype, + collective_op); + } + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(CollectiveGemmInitHandler, CollectiveGemmInitFFI, + FFI::Bind() + .Arg() // lhs + .Arg() // lhs_scale_inv + .Arg() // rhs + .Arg() // rhs_scale_inv + .Arg() // bias + .Arg() // gelu_input + .Arg() // alpha + .Arg() // beta + .Ret() // output + .Ret() // bias_grad + .Ret() // pre_gelu_out + .Ret() // workspace + .Attr("scaling_mode") + .Attr("lhs_axis_boundary") + .Attr("rhs_axis_boundary") + .Attr("lhs_transposed") + .Attr("rhs_transposed") + .Attr("fuse_bias") + .Attr("fuse_gelu") + .Attr("grad") + .Attr("use_split_accumulator") + .Attr("collective_op")); +#endif //#ifndef USE_ROCM + Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, - Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, - int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, - bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { + Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type bias_grad, + Result_Type pre_gelu_out, Result_Type workspace, JAXX_Scaling_Mode scaling_mode, + int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool 
lhs_transposed, + bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, + bool use_split_accumulator, JAXX_Collective_Op collective_op) { + // cuBLAS workspace + 256 alignment enforcement (+ swizzle scales) + uint8_t *lhs_swizzle_scale_ptr = nullptr, *rhs_swizzle_scale_ptr = nullptr; + auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); + workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); + size_t workspace_size = static_cast(workspace->element_count()) - 256; + if (is_nvfp4_scaling(scaling_mode)) { + auto lhs_scale_size = product(lhs_scale_inv.dimensions()); + auto rhs_scale_size = product(rhs_scale_inv.dimensions()); + workspace_size = workspace_size - lhs_scale_size - rhs_scale_size; + lhs_swizzle_scale_ptr = workspace_ptr; + rhs_swizzle_scale_ptr = workspace_ptr + lhs_scale_size; + workspace_ptr = rhs_swizzle_scale_ptr + rhs_scale_size; + } + auto workspace_ = TensorWrapper(workspace_ptr, std::vector{workspace_size}, DType::kByte); + // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? 
true : rhs_transposed; - auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode, - lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, - rhs_axis_boundary, make_rhs_rowwise); - // Output tensor + auto [lhs_, lhs_shape] = + xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, lhs_swizzle_scale_ptr, + scaling_mode, lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = + xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, rhs_swizzle_scale_ptr, + scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], (rhs_transposed) ? rhs_shape[0] : rhs_shape[1]}; auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); - NVTE_CHECK(out_.numel() == output->element_count(), - "cuBLAS GEMM output buffer size is incorrect, " - "expected ", - out_.numel(), " elements ", to_string_like(out_shape), " but got ", - output->element_count(), " elements ", to_string_like(output->dimensions())); // Bias input to forward pass or bias gradient output from backward pass void *bias_ptr = nullptr; - std::vector bias_shape = {0}; + size_t bias_size = 0; DType bias_dtype = out_dtype; if (fuse_bias) { - if (!grad) { + if (grad) { NVTE_CHECK(bias_grad->untyped_data() == bias.untyped_data(), "Missing operand-output aliasing in GemmPrimitive: bias <-> bias_grad"); } - bias_ptr = bias_grad->untyped_data(); - bias_shape.at(0) = bias_grad->dimensions().front(); - bias_dtype = convert_ffi_datatype_to_te_dtype(bias_grad->element_type()); + bias_ptr = bias.untyped_data(); + bias_size = product(bias.dimensions()); + bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); } - auto bias_ = TensorWrapper(bias_ptr, bias_shape, bias_dtype); + auto bias_ = TensorWrapper(bias_ptr, 
std::vector{bias_size}, bias_dtype); // Pre-GeLU output from forward pass or input to backward pass void *pre_gelu_ptr = nullptr; @@ -135,17 +245,96 @@ Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_i } auto pre_gelu_ = TensorWrapper(pre_gelu_ptr, pre_gelu_shape, pre_gelu_dtype); - // cuBLAS workspace + 256 alignment enforcement - auto workspace_ptr = reinterpret_cast(workspace->untyped_data()); - workspace_ptr = move_ptr_to_next_256B_aligned(workspace_ptr); - std::vector workspace_shape = {static_cast(workspace->element_count()) - 256}; - auto workspace_ = TensorWrapper(workspace_ptr, workspace_shape, DType::kByte); - - // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - nvte_cublas_gemm(rhs_.data(), lhs_.data(), out_.data(), bias_.data(), pre_gelu_.data(), - rhs_transposed, lhs_transposed, grad, workspace_.data(), false, - use_split_accumulator, num_math_sm, stream); + + float one = 1.; + float zero = 0.; + // alpha, beta + float *alpha_ptr = &one, *beta_ptr = &zero; + if (is_nvfp4_scaling(scaling_mode)) { + NVTE_CHECK(alpha.element_count() == 1 && + convert_ffi_datatype_to_te_dtype(alpha.element_type()) == DType::kFloat32); + alpha_ptr = reinterpret_cast(alpha.untyped_data()); + NVTE_CHECK(beta.element_count() == 1 && + convert_ffi_datatype_to_te_dtype(beta.element_type()) == DType::kFloat32); + beta_ptr = reinterpret_cast(beta.untyped_data()); + } + + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sm); + if (fuse_bias) config.set_bias_tensor(bias_.data()); + if (fuse_gelu) { + config.set_with_gelu_epilogue(true); + config.set_epilogue_aux_tensor(pre_gelu_.data()); + } + + if (collective_op == JAXX_Collective_Op::NONE) { + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + 
NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), " elements ", + to_string_like(out_shape), " but got ", output->element_count(), " elements ", + to_string_like(output->dimensions())); + NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, + ", out_shape[1]=", out_shape[1]); + + // Launch TE/common kernel with swapped LHS/RHS for cuBLAS column-major order + nvte_cublas_gemm_v2(rhs_transposed /*transa*/, lhs_transposed /*transb*/, alpha_ptr, + rhs_.data() /*A*/, lhs_.data() /*B*/, beta_ptr, out_.data() /*C*/, + out_.data() /*D*/, workspace_.data(), config, stream); + } else { +#ifdef USE_ROCM + //TODO: better assert + NVTE_ERROR("ROCm TE jax does not integrate userbuffer for now"); +#else + std::vector buffer_shape{0, 0}; + DType buffer_dtype = out_dtype; + auto &comm_handler = CommunicatorHandler::get(); + if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + buffer_shape[0] = lhs_shape[0] * comm_handler.tp_size; + buffer_shape[1] = lhs_shape[1]; + out_shape[0] = out_shape[0] * comm_handler.tp_size; + buffer_dtype = convert_ffi_datatype_to_te_dtype(lhs.element_type()); + } else if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + buffer_shape[0] = out_shape[0]; + buffer_shape[1] = out_shape[1]; + out_shape[0] = out_shape[0] / comm_handler.tp_size; + } + NVTE_CHECK(!fuse_bias || bias_size == out_shape[1], "bias_size=", bias_size, + ", out_shape[1]=", out_shape[1]); + auto executor = CollectiveGemmPlanRegistry::getInstance().get_executor( + buffer_shape, buffer_dtype, collective_op); + if (collective_op == JAXX_Collective_Op::REDUCE_SCATTER) { + auto ubuf_out_ = TensorWrapper(executor->get_ubuf_dptr(), buffer_shape, out_dtype); + // Prepare the auxiliary buffer for the reduce-scattered GEMM output + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size 
is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + + // Launch GEMM+RS + executor->split_overlap_rs(rhs_, rhs_transposed, lhs_, lhs_transposed, ubuf_out_, bias_, + pre_gelu_, workspace_, grad, false, use_split_accumulator, out_, + stream); + + } else if (collective_op == JAXX_Collective_Op::ALL_GATHER) { + auto aux_out_ = TensorWrapper(nullptr, std::vector{0}, out_dtype); // Empty + + auto out_ = TensorWrapper(output->untyped_data(), out_shape, out_dtype); + NVTE_CHECK(out_.numel() == output->element_count(), + "cuBLAS GEMM output buffer size is incorrect, expected ", out_.numel(), + " elements ", to_string_like(out_shape), " but got ", output->element_count(), + " elements ", to_string_like(output->dimensions())); + // Copy the distributed LHS operand into the local chunk of the communication buffer + executor->copy_into_buffer(stream, lhs_, true, make_lhs_rowwise); + // Launch AG+GEMM + executor->split_overlap_ag(rhs_, rhs_transposed, lhs_, lhs_transposed, out_, bias_, pre_gelu_, + workspace_, grad, false, use_split_accumulator, aux_out_, stream); + } +#endif //#ifdef USE_ROCM + } return ffi_with_cuda_error_check(); } @@ -159,6 +348,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Arg() // rhs_scale_inv .Arg() // bias .Arg() // gelu_input + .Arg() // alpha + .Arg() // beta .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out @@ -171,15 +362,75 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_bias") .Attr("fuse_gelu") .Attr("grad") - .Attr("use_split_accumulator"), + .Attr("use_split_accumulator") + .Attr("collective_op"), GemmFFI_CudaGraph_Traits); +size_t GroupedGemmGetGroupSizes(cudaStream_t stream, size_t num_gemms, int32_t *dev_group_sizes, + int32_t *host_group_sizes) { + static std::once_flag init_flag; + static cudaEvent_t d2h_event; + static size_t host_num_gemms; + static const size_t 
max_num_gemms = 1024; + //static int32_t host_group_sizes_internal[max_num_gemms]; + static int32_t *host_group_sizes_internal = nullptr; + auto init = [&]() { + NVTE_CHECK_CUDA(cudaEventCreate(&d2h_event)); + NVTE_CHECK_CUDA(cudaMallocHost(&host_group_sizes_internal, sizeof(int32_t) * max_num_gemms)); + }; + std::call_once(init_flag, init); + + NVTE_CHECK(dev_group_sizes == nullptr || host_group_sizes == nullptr, + "Only one of dev_group_sizes and host_group_sizes can be non-nullptr."); + + if (dev_group_sizes != nullptr) { + NVTE_CHECK(num_gemms <= max_num_gemms, "num_gemms ", num_gemms, " exceeds the maximum ", + "supported number ", max_num_gemms, " to be downloaded in advance."); + host_num_gemms = num_gemms; + // Wait for current compute stream to finish + cudaStream_t compute_stream_0 = nvte_get_compute_stream(0); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(compute_stream_0, d2h_event)); + // Async copy group_sizes from device to host + size_t copy_bytes = sizeof(int32_t) * num_gemms; + NVTE_CHECK_CUDA(cudaMemcpyAsync(host_group_sizes_internal, dev_group_sizes, copy_bytes, + cudaMemcpyDeviceToHost, compute_stream_0)); + NVTE_CHECK_CUDA(cudaEventRecord(d2h_event, compute_stream_0)); + return num_gemms; + } + + if (host_group_sizes != nullptr) { + if (host_num_gemms == 0) return 0; + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the previous value ", host_num_gemms, "."); + // Wait for the async copy to finish, then copy group_sizes to user buffer + // Note: This may break cudaGraph. 
+ NVTE_CHECK_CUDA(cudaEventSynchronize(d2h_event)); + memcpy(host_group_sizes, host_group_sizes_internal, sizeof(int32_t) * host_num_gemms); + return host_num_gemms; + } +} + +Error_Type GroupedGemmD2HGroupSizesFFI(cudaStream_t stream, Buffer_Type group_sizes, + Result_Type dummy_output, size_t num_gemms) { + int32_t *dev_group_sizes = reinterpret_cast(group_sizes.untyped_data()); + GroupedGemmGetGroupSizes(stream, num_gemms, dev_group_sizes, nullptr); + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmD2HGroupSizesHandler, GroupedGemmD2HGroupSizesFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // group_sizes + .Ret() // dummy_output + .Attr("num_gemms")); + Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad) { + bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -300,11 +551,18 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. 
- cudaStreamSynchronize(stream); + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); if (!is_grouped_dense_wgrad) { NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, @@ -423,9 +681,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // point to swizzled scale_inv data (store on workspace, only used for GEMM). // Note: even if is_empty_gemm is true, sinv are still non-empty, need to move the pointers auto lhs_sinv_shape_i = - get_mxfp8_scale_shape(lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); + get_block_scale_shape(scaling_mode, lhs_shape_i[0], lhs_shape_i[1], lhs_use_colwise); auto rhs_sinv_shape_i = - get_mxfp8_scale_shape(rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); + get_block_scale_shape(scaling_mode, rhs_shape_i[0], rhs_shape_i[1], rhs_use_colwise); lhs_sinv_size_i = lhs_sinv_shape_i[0] * lhs_sinv_shape_i[1]; rhs_sinv_size_i = rhs_sinv_shape_i[0] * rhs_sinv_shape_i[1]; if (lhs_use_colwise) { @@ -514,6 +772,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t num_non_empty_gemms = lhs_list.size(); +#ifndef USE_ROCM if (is_mxfp8_scaling) { for (int i = 0; i < num_non_empty_gemms; i++) { // The i-th GEMM will use the (i % num_streams)-th stream to compute, @@ -525,6 +784,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type 
nvte_swizzle_scaling_factors(rhs_swizzle_list[i], rhs_list[i], stream_i); } } +#endif // Launch zero-out kernels before the GEMM calls to use the sync in the multi-stream GEMM size_t num_zero_outs = zero_out_dptr_list.size(); @@ -563,7 +823,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad")); + .Attr("is_grouped_dense_wgrad") + .Attr("use_async_d2h_group_sizes")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/misc.cpp b/transformer_engine/jax/csrc/extensions/misc.cpp index ee81b5ad7..176115ade 100644 --- a/transformer_engine/jax/csrc/extensions/misc.cpp +++ b/transformer_engine/jax/csrc/extensions/misc.cpp @@ -26,11 +26,21 @@ std::vector Shape::to_vector() const { return shape; } -std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise) { - auto block_x = is_colwise ? MXFP8_BLOCK_SIZE.y : MXFP8_BLOCK_SIZE.x; - auto block_y = is_colwise ? MXFP8_BLOCK_SIZE.x : MXFP8_BLOCK_SIZE.y; - auto alignment_x = is_colwise ? MXFP8_ALIGNMENT.y : MXFP8_ALIGNMENT.x; - auto alignment_y = is_colwise ? MXFP8_ALIGNMENT.x : MXFP8_ALIGNMENT.y; +std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, + bool is_colwise) { + auto block_size = BLOCK_SIZE(1, 1); + if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { + block_size = MXFP8_BLOCK_SIZE; + } else if (scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) { + block_size = NVFP4_BLOCK_SIZE; + } else { + NVTE_ERROR("Unsupported scaling_mode = ", static_cast(scaling_mode)); + } + auto block_x = is_colwise ? block_size.y : block_size.x; + auto block_y = is_colwise ? block_size.x : block_size.y; + auto alignment_x = is_colwise ? BLOCK_SCALE_ALIGNMENT.y : BLOCK_SCALE_ALIGNMENT.x; + auto alignment_y = is_colwise ? 
BLOCK_SCALE_ALIGNMENT.x : BLOCK_SCALE_ALIGNMENT.y; NVTE_CHECK(M % block_x == 0, "M must be divisble by %zu (got %zu)", block_x, M); NVTE_CHECK(N % block_y == 0, "N must be divisble by %zu (got %zu)", block_y, N); diff --git a/transformer_engine/jax/csrc/extensions/misc.h b/transformer_engine/jax/csrc/extensions/misc.h index af7f54feb..c2d3d6f25 100644 --- a/transformer_engine/jax/csrc/extensions/misc.h +++ b/transformer_engine/jax/csrc/extensions/misc.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -34,17 +36,31 @@ inline size_t product(const std::vector &shape) { return ret; } -enum class QuantizeLayout { +enum class JAXX_Quantize_Layout : int64_t { ROWWISE, COLWISE, ROWWISE_COLWISE, }; +inline bool is_quantize_rowwise(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::ROWWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + +inline bool is_quantize_colwise(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::COLWISE || layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + +inline bool is_quantize_2x2x(const JAXX_Quantize_Layout &layout) { + return layout == JAXX_Quantize_Layout::ROWWISE_COLWISE; +} + enum class JAXX_Scaling_Mode : int64_t { NO_SCALING = 0, DELAYED_TENSOR_SCALING = 1, MXFP8_1D_SCALING = 2, CURRENT_TENSOR_SCALING = 3, + NVFP4_1D_SCALING = 4, + NVFP4_2D_SCALING = 5, }; inline bool is_tensor_scaling(const JAXX_Scaling_Mode &mode) { @@ -56,6 +72,11 @@ inline bool is_block_scaling(const JAXX_Scaling_Mode &mode) { return (mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING); } +inline bool is_nvfp4_scaling(const JAXX_Scaling_Mode &mode) { + return (mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + mode == 
JAXX_Scaling_Mode::NVFP4_2D_SCALING); +} + static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { switch (mode) { case JAXX_Scaling_Mode::NO_SCALING: @@ -70,22 +91,60 @@ static NVTEScalingMode get_nvte_scaling_mode(const JAXX_Scaling_Mode &mode) { case JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING: return NVTEScalingMode::NVTE_DELAYED_TENSOR_SCALING; break; + case JAXX_Scaling_Mode::NVFP4_1D_SCALING: + return NVTEScalingMode::NVTE_NVFP4_1D_SCALING; + break; + case JAXX_Scaling_Mode::NVFP4_2D_SCALING: + // TE common uses the same enum value for 1D and 2D fp4 scaling and instead differentiates them via quant_config.nvfp4_2d_quantization + return NVTEScalingMode::NVTE_NVFP4_1D_SCALING; + break; default: NVTE_ERROR("Invalid Scaling Mode ", static_cast(mode)); break; } } -constexpr struct BlockSize { +struct BLOCK_SIZE { size_t x; size_t y; -} MXFP8_BLOCK_SIZE{1, 32}; -constexpr struct Alignment { - size_t x; - size_t y; -} MXFP8_ALIGNMENT{128, 4}; + constexpr BLOCK_SIZE(int _x, int _y) : x(_x), y(_y) {} +}; -std::vector get_mxfp8_scale_shape(size_t M, size_t N, bool is_colwise); +constexpr BLOCK_SIZE MXFP8_BLOCK_SIZE{1, 32}; +constexpr BLOCK_SIZE NVFP4_BLOCK_SIZE{1, 16}; + +constexpr BLOCK_SIZE BLOCK_SCALE_ALIGNMENT{128, 4}; + +std::vector get_block_scale_shape(JAXX_Scaling_Mode scaling_mode, size_t M, size_t N, + bool is_colwise); + +template +void hash_combine(int64_t &seed, const T &v, Rest... 
rest) { + seed ^= std::hash{}(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2); + (hash_combine(seed, rest), ...); +} + +enum class JAXX_Collective_Op : int64_t { + NONE = 0, + ALL_GATHER = 1, + REDUCE_SCATTER = 2, +}; + +#ifndef USE_ROCM +static CommOverlapType get_nvte_collective_op(const JAXX_Collective_Op &op) { + switch (op) { + case JAXX_Collective_Op::ALL_GATHER: + return CommOverlapType::AG; + break; + case JAXX_Collective_Op::REDUCE_SCATTER: + return CommOverlapType::RS; + break; + default: + NVTE_ERROR("Invalid Collective Op ", static_cast(op)); + break; + } +} +#endif } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index c35bc6668..b01e23c12 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -29,6 +29,7 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(nullptr, out_dtype, input_shape); + output_tensor.set_amax(nullptr, DType::kFloat32, std::vector{1}); // WAR: NVTE Norms query the is_training from whereas columwise_data is allocated if (is_training && scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { @@ -59,12 +60,13 @@ pybind11::tuple GetNormForwardWorkspaceSizes(size_t batch_size, size_t hidden_si } Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, - Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, - Result_Type colwise_output_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type amax_buf, - Result_Type mu_buf, Result_Type rsigma_buf, Result_Type wkspace_buf, - int norm_type, bool zero_centered_gamma, double epsilon, - int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, bool is_2x) { + Buffer_Type amax_buf, Buffer_Type 
gamma_buf, Buffer_Type beta_buf, + Result_Type output_buf, Result_Type colwise_output_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, + double epsilon, int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, bool output_amax_when_no_scaling) { auto in_dtype = convert_ffi_datatype_to_te_dtype(x_buf.element_type()); auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto w_dtype = convert_ffi_datatype_to_te_dtype(gamma_buf.element_type()); @@ -77,11 +79,13 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto *output = output_buf->untyped_data(); auto *rsigma = rsigma_buf->untyped_data(); auto *mu = mu_buf->untyped_data(); - auto *amax = reinterpret_cast(amax_buf->untyped_data()); auto *workspace = wkspace_buf->untyped_data(); + auto *amax = reinterpret_cast(amax_buf.untyped_data()); + auto *updated_amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(amax == updated_amax && amax != nullptr, "amax and updated_amax should be aliased"); + auto _norm_type = static_cast(norm_type); - auto _is_2x = static_cast(is_2x); auto x_size = product(x_buf.dimensions()); auto gamma_size = product(gamma_buf.dimensions()); @@ -106,6 +110,10 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), input_shape); + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || + (scaling_mode == JAXX_Scaling_Mode::NO_SCALING && output_amax_when_no_scaling)) { + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } NVTE_CHECK( scaling_mode != JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING, @@ -123,11 +131,9 @@ Error_Type 
NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - nvte_memset(amax, 0, sizeof(float), stream); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); } - if (_is_2x) { + if (is_quantize_2x2x(quantize_layout)) { output_tensor.set_columnwise_data(colwise_output_buf->untyped_data(), static_cast(out_dtype), input_shape); output_tensor.set_columnwise_scale_inv( @@ -162,13 +168,14 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Ctx() // stream .Arg() // x .Arg() // scale + .Arg() // amax .Arg() // gamma .Arg() // beta .Ret() // output .Ret() // colwise_output .Ret() // scale_inv .Ret() // colwise_scale_inv - .Ret() // amax + .Ret() // updated_amax .Ret() // mu .Ret() // rsigma .Ret() // wkspace @@ -177,9 +184,49 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardHandler, NormForwardFFI, .Attr("epsilon") .Attr("sm_margin") .Attr("scaling_mode") - .Attr("is_2x"), + .Attr("quantize_layout") + .Attr("output_amax_when_no_scaling"), FFI_CudaGraph_Traits); +Error_Type NormForwardInitializeFFI( + cudaStream_t stream, Buffer_Type x_buf, Buffer_Type scale_buf, Buffer_Type amax_buf, + Buffer_Type gamma_buf, Buffer_Type beta_buf, Result_Type output_buf, + Result_Type colwise_output_buf, Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type mu_buf, Result_Type rsigma_buf, + Result_Type wkspace_buf, int norm_type, bool zero_centered_gamma, double epsilon, + int64_t sm_margin, JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, + bool output_amax_when_no_scaling) { + return wrapInStreamCapture(std::function(NormForwardFFI), stream, x_buf, scale_buf, amax_buf, + gamma_buf, beta_buf, output_buf, colwise_output_buf, scale_inv_buf, + colwise_scale_inv_buf, updated_amax_buf, mu_buf, rsigma_buf, + wkspace_buf, norm_type, 
zero_centered_gamma, epsilon, sm_margin, + scaling_mode, quantize_layout, output_amax_when_no_scaling); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormForwardInitializeHandler, NormForwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // x + .Arg() // scale + .Arg() // amax + .Arg() // gamma + .Arg() // beta + .Ret() // output + .Ret() // colwise_output + .Ret() // scale_inv + .Ret() // colwise_scale_inv + .Ret() // updated_amax + .Ret() // mu + .Ret() // rsigma + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("epsilon") + .Attr("sm_margin") + .Attr("scaling_mode") + .Attr("quantize_layout") + .Attr("output_amax_when_no_scaling")); + pybind11::tuple GetNormBackwardWorkspaceSizes(size_t batch_size, size_t hidden_size, DType in_dtype, DType w_dtype, NVTE_Norm_Type norm_type, bool zero_centered_gamma, int sm_margin) { @@ -305,5 +352,32 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardHandler, NormBackwardFFI, .Attr("sm_margin"), FFI_CudaGraph_Traits); +Error_Type NormBackwardInitializeFFI(cudaStream_t stream, Buffer_Type dz_buf, Buffer_Type x_buf, + Buffer_Type mu_buf, Buffer_Type rsigma_buf, + Buffer_Type gamma_buf, Result_Type xgrad_buf, + Result_Type wgrad_buf, Result_Type dbeta_buf, + Result_Type wkspace_buf, int64_t norm_type, + bool zero_centered_gamma, int64_t sm_margin) { + return wrapInStreamCapture(std::function(NormBackwardFFI), stream, dz_buf, x_buf, mu_buf, + rsigma_buf, gamma_buf, xgrad_buf, wgrad_buf, dbeta_buf, wkspace_buf, + norm_type, zero_centered_gamma, sm_margin); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(NormBackwardInitializeHandler, NormBackwardInitializeFFI, + FFI::Bind() + .Ctx() // stream + .Arg() // dz + .Arg() // x + .Arg() // mu + .Arg() // rsigma + .Arg() // gamma + .Ret() // xgrad + .Ret() // wgrad + .Ret() // dbeta + .Ret() // wkspace + .Attr("norm_type") + .Attr("zero_centered_gamma") + .Attr("sm_margin")); + } // namespace jax } // namespace transformer_engine diff --git 
a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 563675988..937dde228 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -1,12 +1,16 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ #include "../extensions.h" +#ifndef USE_ROCM +#include "cgemm_helper.h" +#endif //#ifndef USE_ROCM +#include "common/util/cuda_runtime.h" namespace transformer_engine { namespace jax { @@ -22,8 +26,12 @@ pybind11::dict Registrations() { pybind11::dict dict; // Activation - dict["te_act_lu_ffi"] = EncapsulateFFI(ActLuHandler); - dict["te_dact_dbias_quantize_ffi"] = EncapsulateFFI(DActLuDBiasQuantizeHandler); + dict["te_act_lu_ffi"] = + pybind11::dict(pybind11::arg("initialize") = EncapsulateFFI(ActLuInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(ActLuHandler)); + dict["te_dact_dbias_quantize_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(DActLuDBiasQuantizeInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(DActLuDBiasQuantizeHandler)); // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); @@ -45,9 +53,11 @@ pybind11::dict Registrations() { // Normalization dict["te_norm_forward_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormForwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormForwardHandler)); dict["te_norm_backward_ffi"] = 
pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CudnnHandleInitHandler), + pybind11::arg("initialize") = EncapsulateFFI(NormBackwardInitializeHandler), pybind11::arg("execute") = EncapsulateFFI(NormBackwardHandler)); // Attention @@ -60,13 +70,20 @@ pybind11::dict Registrations() { // GEMM dict["te_gemm_ffi"] = - pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CollectiveGemmInitHandler), pybind11::arg("execute") = EncapsulateFFI(GemmHandler)); // Grouped GEMM + dict["te_grouped_gemm_d2h_group_sizes_ffi"] = + pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), + pybind11::arg("execute") = EncapsulateFFI(GroupedGemmD2HGroupSizesHandler)); dict["te_grouped_gemm_ffi"] = pybind11::dict(pybind11::arg("prepare") = EncapsulateFFI(CublasHandleInitHandler), pybind11::arg("execute") = EncapsulateFFI(GroupedGemmHandler)); + // Amax + dict["te_rht_amax_ffi"] = pybind11::dict( + pybind11::arg("initialize") = EncapsulateFFI(RHTAmaxCalculationInitializeHandler), + pybind11::arg("execute") = EncapsulateFFI(RHTAmaxCalculationHandler)); #else // Normalization dict["te_norm_forward_ffi"] = EncapsulateFFI(NormForwardHandler); @@ -102,6 +119,10 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); m.def("nvte_get_qkv_format", &nvte_get_qkv_format); m.def("is_non_nt_fp8_gemm_supported", &nvte_is_non_tn_fp8_gemm_supported); +#ifndef USE_ROCM + m.def("initialize_cgemm_communicator", &InitializeCgemmCommunicator); + m.def("get_cgemm_num_max_streams", &GetCgemmNumMaxStreams); +#endif pybind11::enum_(m, "DType", pybind11::module_local()) .value("kByte", DType::kByte) @@ -111,7 +132,9 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("kFloat16", DType::kFloat16) .value("kBFloat16", DType::kBFloat16) .value("kFloat8E4M3", DType::kFloat8E4M3) - .value("kFloat8E5M2", 
DType::kFloat8E5M2); + .value("kFloat8E5M2", DType::kFloat8E5M2) + .value("kFloat8E8M0", DType::kFloat8E8M0) + .value("kFloat4E2M1", DType::kFloat4E2M1); pybind11::enum_(m, "NVTE_Bias_Type", pybind11::module_local()) .value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) @@ -151,6 +174,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("QGEGLU", NVTE_Activation_Type::QGEGLU) .value("SRELU", NVTE_Activation_Type::SRELU) .value("SREGLU", NVTE_Activation_Type::SREGLU) + .value("CLAMPED_SWIGLU", NVTE_Activation_Type::CLAMPED_SWIGLU) .export_values(); #ifndef USE_ROCM @@ -176,13 +200,20 @@ PYBIND11_MODULE(transformer_engine_jax, m) { .value("DELAYED_TENSOR_SCALING", JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) .value("MXFP8_1D_SCALING", JAXX_Scaling_Mode::MXFP8_1D_SCALING) .value("CURRENT_TENSOR_SCALING", JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) + .value("NVFP4_1D_SCALING", JAXX_Scaling_Mode::NVFP4_1D_SCALING) + .value("NVFP4_2D_SCALING", JAXX_Scaling_Mode::NVFP4_2D_SCALING) + .export_values(); + + pybind11::enum_(m, "JAXX_Quantize_Layout", pybind11::module_local()) + .value("ROWWISE", JAXX_Quantize_Layout::ROWWISE) + .value("COLWISE", JAXX_Quantize_Layout::COLWISE) + .value("ROWWISE_COLWISE", JAXX_Quantize_Layout::ROWWISE_COLWISE) .export_values(); - pybind11::enum_(m, "QuantizeLayout", - pybind11::module_local()) - .value("ROWWISE", transformer_engine::jax::QuantizeLayout::ROWWISE) - .value("COLWISE", transformer_engine::jax::QuantizeLayout::COLWISE) - .value("ROWWISE_COLWISE", transformer_engine::jax::QuantizeLayout::ROWWISE_COLWISE) + pybind11::enum_(m, "JAXX_Collective_Op", pybind11::module_local()) + .value("NONE", JAXX_Collective_Op::NONE) + .value("ALL_GATHER", JAXX_Collective_Op::ALL_GATHER) + .value("REDUCE_SCATTER", JAXX_Collective_Op::REDUCE_SCATTER) .export_values(); } diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index d17d83ec1..9a0a87d69 100644 --- 
a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -5,8 +5,11 @@ ************************************************************************/ #include +#include + #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" #include "xla/ffi/api/c_api.h" @@ -15,9 +18,9 @@ namespace transformer_engine { namespace jax { pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, - DType in_dtype, DType out_dtype, + DType in_dtype, DType out_dtype, DType scale_dtype, JAXX_Scaling_Mode scaling_mode, - QuantizeLayout q_layout) { + JAXX_Quantize_Layout q_layout) { auto input_shape = std::vector{batch_size, hidden_size}; auto output_shape = std::vector{batch_size, hidden_size}; auto output_trans_shape = std::vector{hidden_size, batch_size}; @@ -30,32 +33,41 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ // this function. We pass a dummy pointer as a workaround. 
int temp = 0; + bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; + auto input_tensor = TensorWrapper(reinterpret_cast(&temp), input_shape, in_dtype); auto dbias_tensor = TensorWrapper(reinterpret_cast(&temp), dbias_shape, in_dtype); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); + auto scale_shape = std::vector{1}; // Only the pointers will be checked for scale_inv, thus the shapes do not matter - if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::ROWWISE) { + if (is_quantize_rowwise(q_layout)) { output_tensor.set_rowwise_data(reinterpret_cast(&temp), out_dtype, output_shape); - if (is_fp8_dtype(out_dtype)) { - output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, - std::vector{1}); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + if (is_nvfp4) + scale_shape = get_block_scale_shape(scaling_mode, batch_size, hidden_size, false); + output_tensor.set_rowwise_scale_inv(reinterpret_cast(&temp), scale_dtype, + scale_shape); } } - if (q_layout == QuantizeLayout::ROWWISE_COLWISE || q_layout == QuantizeLayout::COLWISE) { + if (is_quantize_colwise(q_layout)) { auto &tmp_shape = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING ? 
output_trans_shape : output_shape; output_tensor.set_columnwise_data(reinterpret_cast(&temp), out_dtype, tmp_shape); // Only the pointers will be checked for scale_inv, thus the shapes do not matter - if (is_fp8_dtype(out_dtype)) { - output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), DType::kFloat32, - std::vector{1}); + if (scaling_mode != JAXX_Scaling_Mode::NO_SCALING) { + if (is_nvfp4) + scale_shape = + get_block_scale_shape(scaling_mode, hidden_size, batch_size, false); //Transpose + output_tensor.set_columnwise_scale_inv(reinterpret_cast(&temp), scale_dtype, + scale_shape); } } - if (is_fp8_dtype(out_dtype) && scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) { + if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4) { output_tensor.set_amax(reinterpret_cast(&temp), DType::kFloat32, std::vector{1}); output_tensor.set_scale(reinterpret_cast(&temp), DType::kFloat32, @@ -72,22 +84,23 @@ pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Buffer_Type amax_buf, Result_Type output_buf, - Result_Type output_trans_buf, Result_Type scale_inv_buf, - Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, - bool is_dbias, int64_t flatten_axis) { + Buffer_Type amax_buf, Buffer_Type sr_rng_state, + Buffer_Type post_rht_amax_buf, Buffer_Type rht_matrix_buf, + Result_Type output_buf, Result_Type output_trans_buf, + Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, + Result_Type updated_amax_buf, Result_Type dbias_buf, + Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, + JAXX_Quantize_Layout quantize_layout, bool is_dbias, + int64_t flatten_axis, bool stochastic_rounding, bool use_rht) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); 
auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type()); auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); - NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for quantization."); + NVTE_CHECK(is_fp8_dtype(out_dtype) || is_fp4_dtype(out_dtype), + "Output datatype must be FP8 or FP4 for quantization."); auto *input = input_buf.untyped_data(); - auto const quantize_layout = static_cast(quantize_layout_enum); - auto *output = output_buf->untyped_data(); auto *output_trans = output_trans_buf->untyped_data(); auto *dbias = dbias_buf->untyped_data(); @@ -112,39 +125,105 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; + bool const is_nvfp4 = scaling_mode == JAXX_Scaling_Mode::NVFP4_1D_SCALING || + scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING; + + NVTE_CHECK(!stochastic_rounding || is_nvfp4, "Stochastic rounding is only supported for NVFP4."); + NVTE_CHECK(!use_rht || is_nvfp4, "RHT is only supported for NVFP4 scaling"); - if (quantize_layout == QuantizeLayout::ROWWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { + if (is_quantize_rowwise(quantize_layout)) { output_tensor.set_rowwise_data(output, out_dtype, output_shape); - if (is_fp8_dtype(out_dtype)) { - if (is_tensor_scaling) { - float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); - NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); - NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); - output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - 
convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{1}); - } else { - output_tensor.set_rowwise_scale_inv( - scale_inv_buf->untyped_data(), - convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), - std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), - product(scale_inv_buf->dimensions(), flatten_axis, - scale_inv_buf->dimensions().size())}); - } + if (is_tensor_scaling) { + float *scale = reinterpret_cast(scale_buf.untyped_data()); + float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); + NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); + NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); + output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{1}); + } else { + output_tensor.set_rowwise_scale_inv( + scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), + std::vector{product(scale_inv_buf->dimensions(), 0, flatten_axis), + product(scale_inv_buf->dimensions(), flatten_axis, + scale_inv_buf->dimensions().size())}); } } - if (quantize_layout == QuantizeLayout::COLWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE) { - auto &tmp_shape = (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING) - ? 
output_trans_shape - : output_shape; + if (is_nvfp4) { + float *amax = reinterpret_cast(amax_buf.untyped_data()); + NVTE_CHECK(amax != nullptr, "amax must be provided for NVFP4"); + output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); + } + + QuantizationConfigWrapper quant_config{}; + if (scaling_mode == JAXX_Scaling_Mode::NVFP4_2D_SCALING) { + quant_config.set_nvfp4_2d_quantization(true); + } + + // Stochastic rounding + quant_config.set_stochastic_rounding(stochastic_rounding); + TensorWrapper sr_rng_state_tensor(sr_rng_state.untyped_data(), std::vector{2}, + DType::kInt64); + if (stochastic_rounding) { + NVTE_CHECK(sr_rng_state.size_bytes() == 2 * sizeof(uint64_t), + "rng_state must be of type int64[2]"); + NVTE_CHECK(sr_rng_state.untyped_data() != nullptr, "rng_state must be provided for SR"); + quant_config.set_rng_state(sr_rng_state_tensor.data()); + } + + if (is_quantize_colwise(quantize_layout)) { +#ifndef USE_ROCM + if (is_nvfp4 && use_rht) { + if (is_quantize_2x2x(quantize_layout)) { + // Do regular rowwise quantization without RHT + nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); + } + + TensorWrapper out_transpose(get_nvte_scaling_mode(scaling_mode)); + + // nvte_hadamard_transform_cast_fusion_columnwise expects the colwise data to be populated in the rowwise buffers on TensorWrapper + out_transpose.set_rowwise_data(output_trans, out_dtype, output_trans_shape); + auto const colwise_flatten_axis = output_trans_buf->dimensions().size() - flatten_axis; + out_transpose.set_rowwise_scale_inv( + colwise_scale_inv_buf->untyped_data(), + convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), + std::vector{product(colwise_scale_inv_buf->dimensions(), 0, colwise_flatten_axis), + product(colwise_scale_inv_buf->dimensions(), colwise_flatten_axis, + colwise_scale_inv_buf->dimensions().size())}); + + float *post_rht_amax = reinterpret_cast(post_rht_amax_buf.untyped_data()); + NVTE_CHECK(post_rht_amax != 
nullptr, "Post-RHT colwise amax must be provided for NVFP4"); + out_transpose.set_amax(post_rht_amax, DType::kFloat32, std::vector{1}); + + bool const eligible_for_rht_cast_fusion = + input_tensor.dtype() == DType::kBFloat16 && m % 64 == 0 && n % 128 == 0; + NVTE_CHECK(eligible_for_rht_cast_fusion, "RHT cast fusion conditions not met"); + + NVTE_CHECK( + convert_ffi_datatype_to_te_dtype(rht_matrix_buf.element_type()) == DType::kBFloat16, + "RHT matrix must be bf16"); + NVTE_CHECK(rht_matrix_buf.dimensions().size() == 2 && rht_matrix_buf.dimensions()[0] == 16 && + rht_matrix_buf.dimensions()[1] == 16, + "RHT matrix must be 16x16"); + TensorWrapper rht_matrix_tensor(rht_matrix_buf.untyped_data(), std::vector{16, 16}, + DType::kBFloat16); + + nvte_hadamard_transform_cast_fusion_columnwise(input_tensor.data(), out_transpose.data(), + rht_matrix_tensor.data(), quant_config, + stream); + + return ffi_with_cuda_error_check(); + } +#endif // #ifndef USE_ROCM + + bool const is_colwise_transposed = + scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || is_nvfp4; + auto &tmp_shape = is_colwise_transposed ? output_trans_shape : output_shape; output_tensor.set_columnwise_data(output_trans, out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling auto &tmp_buf = is_tensor_scaling ? 
scale_inv_buf : colwise_scale_inv_buf; @@ -154,26 +233,30 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{1}); } else { + auto colwise_flatten_axis = flatten_axis; + if (is_colwise_transposed) { + // convert flatten_axis from N layout to T layout + colwise_flatten_axis = tmp_buf->dimensions().size() - flatten_axis; + } output_tensor.set_columnwise_scale_inv( tmp_buf->untyped_data(), convert_ffi_datatype_to_te_dtype(tmp_buf->element_type()), std::vector{ - product(tmp_buf->dimensions(), 0, flatten_axis), - product(tmp_buf->dimensions(), flatten_axis, tmp_buf->dimensions().size())}); + product(tmp_buf->dimensions(), 0, colwise_flatten_axis), + product(tmp_buf->dimensions(), colwise_flatten_axis, tmp_buf->dimensions().size())}); } } - if (scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING) { - output_tensor.set_amax(nullptr, DType::kFloat32, std::vector{1}); - } - auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype); auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype); if (is_dbias) { + NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NVFP4_2D_SCALING, + "DBias quantization is not supported for NVFP4_2D_SCALING as fused dbias API cannot " + "take quant_config as input."); nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(), workspace_tensor.data(), stream); } else { - nvte_quantize(input_tensor.data(), output_tensor.data(), stream); + nvte_quantize_v2(input_tensor.data(), output_tensor.data(), quant_config, stream); } return ffi_with_cuda_error_check(); } @@ -184,6 +267,9 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Arg() // input .Arg() // scale .Arg() // amax + .Arg() // sr_rng_state + .Arg() // colwise amax + .Arg() // rht matrix .Ret() // output .Ret() // colwise output .Ret() // scale_inv @@ -192,9 +278,11 @@ 
XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ret() // dbias .Ret() // wkspace .Attr("scaling_mode") - .Attr("q_layout") + .Attr("q_layout") .Attr("is_dbias") - .Attr("flatten_axis"), + .Attr("flatten_axis") + .Attr("stochastic_rounding") + .Attr("use_rht"), FFI_CudaGraph_Traits); Error_Type DequantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf, @@ -232,7 +320,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty Buffer_Type group_sizes, Result_Type outputs, Result_Type colwise_outputs, Result_Type scale_invs, Result_Type colwise_scale_invs, Result_Type amaxs, - JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, + JAXX_Scaling_Mode scaling_mode, JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { NVTE_CHECK(scaling_mode != JAXX_Scaling_Mode::NO_SCALING, "Unsupported scaling mode: ", static_cast(scaling_mode)); @@ -245,7 +333,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto group_size_dtype = convert_ffi_datatype_to_te_dtype(group_sizes.element_type()); auto sinv_dtype = convert_ffi_datatype_to_te_dtype(scale_invs->element_type()); auto amax_dtype = convert_ffi_datatype_to_te_dtype(amaxs->element_type()); - auto const quantize_layout = static_cast(quantize_layout_enum); auto *input_ptr = reinterpret_cast(inputs.untyped_data()); auto *scale_ptr = reinterpret_cast(scales.untyped_data()); @@ -255,10 +342,6 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto *colwise_sinv_ptr = reinterpret_cast(colwise_scale_invs->untyped_data()); auto *amax_ptr = reinterpret_cast(amaxs->untyped_data()); - bool has_rowwise = quantize_layout == QuantizeLayout::ROWWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE; - bool has_colwise = quantize_layout == QuantizeLayout::COLWISE || - quantize_layout == QuantizeLayout::ROWWISE_COLWISE; bool is_delayed_scaling = scaling_mode == 
JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING; bool const is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; @@ -268,8 +351,8 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty size_t output_dtype_bytes = te_dtype_bytes(out_dtype); size_t sinv_dtype_bytes = te_dtype_bytes(sinv_dtype); size_t group_size_dtype_bytes = te_dtype_bytes(group_size_dtype); - size_t colwise_output_dtype_bytes = has_colwise ? output_dtype_bytes : 0; - size_t colwise_sinv_dtype_bytes = has_colwise ? sinv_dtype_bytes : 0; + size_t colwise_output_dtype_bytes = is_quantize_colwise(quantize_layout) ? output_dtype_bytes : 0; + size_t colwise_sinv_dtype_bytes = is_quantize_colwise(quantize_layout) ? sinv_dtype_bytes : 0; size_t scale_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(scale_dtype) : 0; size_t amax_dtype_bytes = is_tensor_scaling ? te_dtype_bytes(amax_dtype) : 0; @@ -332,7 +415,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty auto inp_i = TensorWrapper(static_cast(input_ptr), shape_i, in_dtype); auto out_i = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); - if (has_rowwise) { + if (is_quantize_rowwise(quantize_layout)) { out_i.set_rowwise_data(static_cast(output_ptr), out_dtype, shape_i); if (is_fp8_dtype(out_dtype)) { @@ -344,14 +427,14 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty sinv_size = 1; } else { const bool is_colwise = false; - auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise); out_i.set_rowwise_scale_inv(static_cast(sinv_ptr), sinv_dtype, sinv_shape_i); sinv_size = product(sinv_shape_i); } } } - if (has_colwise) { + if (is_quantize_colwise(quantize_layout)) { auto &tmp_shape = is_tensor_scaling ? 
shape_trans_i : shape_i; out_i.set_columnwise_data(static_cast(colwise_output_ptr), out_dtype, tmp_shape); // For 2x delayed scaling, the scale buffer is shared between rowwise and columnwise scaling @@ -363,7 +446,7 @@ Error_Type GroupedQuantizeFFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Ty colwise_sinv_size = 1; } else { const bool is_colwise = true; - auto sinv_shape_i = get_mxfp8_scale_shape(m_i, n, is_colwise); + auto sinv_shape_i = get_block_scale_shape(scaling_mode, m_i, n, is_colwise); out_i.set_columnwise_scale_inv(static_cast(colwise_sinv_ptr), sinv_dtype, sinv_shape_i); colwise_sinv_size = product(sinv_shape_i); @@ -410,7 +493,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // scale_inv colwise .Ret() // amax .Attr("scaling_mode") - .Attr("q_layout") + .Attr("q_layout") .Attr("flatten_axis")); } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8087159a3..613455b6c 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -11,12 +11,15 @@ from typing import Tuple, Sequence from functools import partial +import warnings import jax import jax.numpy as jnp from . import cpp_extensions as tex +from .cpp_extensions.amax import AmaxScope from .quantize import ( ScaledTensorFactory, + ScaledTensor, ScalingMode, QuantizeLayout, QuantizerSet, @@ -24,7 +27,6 @@ with_sharding_constraint_by_logical_axes, is_fp8_gemm_with_all_layouts_supported, TensorUsage, - get_quantize_config, ) @@ -61,8 +63,11 @@ def dense( kernel: jnp.ndarray, bias: jnp.ndarray = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), + transpose_batch_sequence: bool = False, input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, + output_axes: Tuple[str, ...] 
= None, + collective_op_set: tex.CollectiveOpSet = tex.noop_collective_op_set, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -76,12 +81,20 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract + transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor. + input_axes: Logical axes for sharding the activation input + kernel_axes: Logical axes for sharding the weight matrix + output_axes: Logical axes for sharding the output + collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ - if not get_quantize_config().is_fp8_enabled(): + if transpose_batch_sequence: + warnings.warn("transpose_batch_sequence is not well tested, use with caution!") + + if quantizer_set == noop_quantizer_set: input_dtype = x.dtype kernel = kernel.astype(input_dtype) @@ -90,29 +103,28 @@ def dense( kernel, bias, contracting_dims, + transpose_batch_sequence, input_axes, kernel_axes, + output_axes, + collective_op_set, quantizer_set, ) return output -@partial( - jax.custom_vjp, - nondiff_argnums=( - 3, - 4, - 5, - ), -) +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7, 8)) def _dense( x, kernel, bias, contracting_dims, + transpose_batch_sequence, input_axes, kernel_axes, - quantizer_set, + output_axes, + collective_op_set, + quantizer_set, # need to be a diff_arg for DelayedScaling state management ): """Internal implementation of dense layer transformation with custom VJP. 
@@ -124,8 +136,11 @@ def _dense( kernel: Weight matrix bias: Optional bias tensor contracting_dims: Contracting dimensions specification + transpose_batch_sequence: Transpose the batch and sequence dimensions of the input tensor. input_axes: Logical axes for sharding the activation input + output_axes: Logical axes for sharding the output_axes kernel_axes: Logical axes for sharding the weight matrix + collective_op_set: A set of CollectiveOp objects for forward and backward passes. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -136,8 +151,11 @@ def _dense( kernel, bias, contracting_dims, + transpose_batch_sequence, input_axes, kernel_axes, + output_axes, + collective_op_set, quantizer_set, ) return output @@ -148,8 +166,11 @@ def _dense_fwd_rule( kernel, bias, contracting_dims, + transpose_batch_sequence, input_axes, kernel_axes, + output_axes, + collective_op_set, quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -175,6 +196,8 @@ def _dense_fwd_rule( x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -182,6 +205,7 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -191,17 +215,20 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set.forward, ) + output = with_sharding_constraint_by_logical_axes(output, output_axes) if use_bias and 
tex.gemm_uses_jax_dot(): bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_x.get_tensor(usage=TensorUsage.LHS_TRANS), - casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS), + casted_x.get_tensor(usage=TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x), + casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel), x.shape, kernel.shape, use_bias, @@ -212,8 +239,15 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, ctx, grad -): # pylint: disable=unused-argument + contracting_dims, + transpose_batch_sequence, + input_axes, + kernel_axes, + output_axes, + collective_op_set, + ctx, + grad, +): """Backward pass rule for dense layer transformation. Returns: @@ -228,6 +262,7 @@ def _dense_bwd_rule( quantizer_set, flatten_axis_k, ) = ctx + grad = with_sharding_constraint_by_logical_axes(grad, output_axes) fwd_x_contracting_dims, fwd_k_contracting_dims = map( tex.sanitize_dims, (casted_x_lhs.ndim, casted_kernel_rhs.ndim), contracting_dims @@ -238,6 +273,8 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # GEMM NT @@ -254,8 +291,9 @@ def _dense_bwd_rule( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op_set.backward, ) - dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) # GEMM TN # x_non_contracting_dims @@ -267,7 +305,10 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) + + dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) wgrad = 
with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) return dgrad, wgrad, dbias, quantizer_set @@ -488,8 +529,12 @@ def _grouped_dense_fwd_rule( ctx = ( group_sizes, - ctx_x, - ctx_kernel, + ctx_x.checkpoint(quantizer_set.x) if isinstance(ctx_x, ScaledTensor) else ctx_x, + ( + ctx_kernel.checkpoint(quantizer_set.kernel) + if isinstance(ctx_kernel, ScaledTensor) + else ctx_kernel + ), x.shape, kernel.shape, use_bias, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index c548c54ef..b5f159022 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -6,7 +6,7 @@ """ from functools import reduce import operator -from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType +from typing import Any, Callable, Iterable, List, Sequence, Tuple, Union, NewType, Optional import numpy as np import jax.numpy as jnp @@ -15,7 +15,6 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from transformer_engine.common import recipe from ..dense import dense @@ -34,11 +33,11 @@ ) from ..quantize import ( QuantizerFactory, - get_quantize_config, - QuantizeMeta, + get_global_quantize_recipe, QuantizeMetaSet, - ScalingMode, TensorSource, + get_quantize_config_with_recipe, + noop_quantizer_set, ) PRNGKey = Any @@ -347,46 +346,44 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method """ def generate_quantizer_set( - self, postfix: str = "", variable_collection: str = None, fp8_recipe=None + self, + postfix: str = "", + variable_collection: str = None, + quantization_checkpoint_name: Optional[str] = None, + fp8_recipe=None, ): """ Generate a set of FP8 meta for a GEMM. 
""" - def generate_quantize_meta(quantizer_name: str): - collection_name = ( - variable_collection - if variable_collection is not None - else get_quantize_config().COLLECTION_NAME - ) - scale = self.variable( - collection_name, - f"{quantizer_name}{postfix}_scale", - jnp.ones, - (1,), - jnp.float32, - ).value - amax_history = self.variable( - collection_name, - f"{quantizer_name}{postfix}_amax_history", - jnp.zeros, - (get_quantize_config().AMAX_HISTORY_LEN,), - jnp.float32, - ).value - return QuantizeMeta(scale=scale, amax_history=amax_history) - - if get_quantize_config().get_scaling_mode( - TensorSource.X - ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling): - x_meta = generate_quantize_meta("x") - kernel_meta = generate_quantize_meta("kernel") - grad_meta = generate_quantize_meta("grad") - quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) - kwargs = {"quantize_meta_set": quantize_meta_set} - else: - kwargs = {} + if fp8_recipe is None: + fp8_recipe = get_global_quantize_recipe() + + quantize_config = get_quantize_config_with_recipe(fp8_recipe) + + collection_name = ( + variable_collection + if variable_collection is not None + else quantize_config.COLLECTION_NAME + ) + + x_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.X, "x" + ) + kernel_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.KERNEL, "kernel" + ) + grad_meta = quantize_config.get_quantize_flax_meta( + self, collection_name, postfix, TensorSource.DGRAD, "grad" + ) + + quantize_meta_set = QuantizeMetaSet(x=x_meta, kernel=kernel_meta, grad=grad_meta) - quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs) + quantizer_set = QuantizerFactory.create_set( + fp8_recipe=fp8_recipe, + quantize_meta_set=quantize_meta_set, + checkpoint_name=quantization_checkpoint_name, + ) return quantizer_set @@ -432,6 +429,10 @@ class 
DenseGeneral(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. + quantization_checkpoint_name: Optional[str], default = None + The name for checkpointing quantizations. """ features: Union[Iterable[int], int] @@ -446,6 +447,8 @@ class DenseGeneral(TransformerEngineBase): axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 input_axes: Tuple[str, ...] = () + transpose_batch_sequence: bool = False + quantization_checkpoint_name: Optional[str] = None def __post_init__(self): if self.kernel_init is None: @@ -490,7 +493,11 @@ def __call__(self, inputs: Array) -> Array: self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + quantizer_set = self.generate_quantizer_set( + quantization_checkpoint_name=self.quantization_checkpoint_name + ) + + if quantizer_set == noop_quantizer_set: kernel = kernel.astype(input_dtype) if self.use_bias: @@ -503,7 +510,6 @@ def __call__(self, inputs: Array) -> Array: else: bias = None - quantizer_set = self.generate_quantizer_set() contract_ind = tuple(range(0, len(axis))) y = dense( inputs, @@ -512,6 +518,7 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -603,7 +610,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): bias_axes: Tuple[str, ...], default = () The name of axes used to shard bias with a corresponding mesh, only used when :attr:`use_bias=True`. - return_layernorm_output: bool, default = True + return_layernorm_output: bool, default = False Indicate whether to return the output of layer normalization. If set False, return None as the second tensor in outputs. 
enable_low_rank_adaptation: bool, default = False @@ -632,6 +639,10 @@ class LayerNormDenseGeneral(TransformerEngineBase): depth_scaling: float, default = None The factor to scale the output from `DenseGeneral`. It should be a float value or None. When None is set, then no scaling is applied. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. + quantization_checkpoint_name: Optional[str], default = None + The name for checkpointing quantizations. """ features: Union[Iterable[int], int] @@ -648,7 +659,7 @@ class LayerNormDenseGeneral(TransformerEngineBase): use_bias: bool = False bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = () - return_layernorm_output: bool = True + return_layernorm_output: bool = False enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 low_rank_adaptation_alpha: float = None @@ -657,6 +668,8 @@ class LayerNormDenseGeneral(TransformerEngineBase): layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] 
= None depth_scaling: float = None + transpose_batch_sequence: bool = False + quantization_checkpoint_name: Optional[str] = None def __post_init__(self): if self.kernel_init is None: @@ -696,10 +709,12 @@ def __call__(self, inputs: Array) -> Array: input_dtype = inputs.dtype ln_output = None - quantizer_set = self.generate_quantizer_set() + quantizer_set = self.generate_quantizer_set( + quantization_checkpoint_name=self.quantization_checkpoint_name + ) fuse_layernorm = ( - get_quantize_config().is_fp8_enabled() + quantizer_set != noop_quantizer_set and not self.return_layernorm_output and self.enable_layernorm ) @@ -750,7 +765,7 @@ def __call__(self, inputs: Array) -> Array: kernel_shape, self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + if quantizer_set == noop_quantizer_set: kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -768,6 +783,7 @@ def __call__(self, inputs: Array) -> Array: dot_input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_input_axes) @@ -775,6 +791,7 @@ def __call__(self, inputs: Array) -> Array: y, kernel, contracting_dims=(axis, contract_ind), + transpose_batch_sequence=self.transpose_batch_sequence, input_axes=self.dot_input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, @@ -892,15 +909,19 @@ class LayerNormMLP(TransformerEngineBase): The name of axes used to shard bias with a corresponding mesh for the weight of the second dense layer transformation. Only used when :attr:`use_bias=True`. - return_layernorm_output: bool, default = True + return_layernorm_output: bool, default = False Indicate whether to return the output of layer normalization. If set False, return None as the second tensor in outputs. 
- activations: Sequence[Union[str, Callable]], default = ('relu',) + activations: Sequence[Union[str, Callable]], default = ('gelu',) The sequence of activation functions to apply after the first dense layer transformation. Each activation has its own transformation layer. + activation_params: dict, default = None + The parameters needed(if any) by the activation functions specified in :attr:`activations`. + At the moment only ('clamped_silu', 'clamped_linear') which is clamped_swiglu used in GPT OSS + need additional parameters. intermediate_dropout_rng_name: str, default = 'dropout' The key in given RNGs via flax.linen.Module.apply that for generating Dropout masks. - intermediate_dropout_rate: float, default = 0.1 + intermediate_dropout_rate: float, default = 0.0 Dropout probability for the dropout op after the :attr:`activations`. intermediate_hidden_dropout_dims: Sequence[int], default = () Dimensions that will share the same dropout mask for hidden @@ -936,6 +957,10 @@ class LayerNormMLP(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. + transpose_batch_sequence: bool, default = False + Indicate whether to transpose the batch and sequence dimensions of the input tensor. + quantization_checkpoint_name: Optional[str], default = None + The name for checkpointing quantizations. """ intermediate_dim: int = 2048 @@ -954,10 +979,11 @@ class LayerNormMLP(TransformerEngineBase): bias_init: Initializer = nn.initializers.zeros bias_axes_1: Tuple[str, ...] = ("act", "mlp") bias_axes_2: Tuple[str, ...] 
= ("embed",) - return_layernorm_output: bool = True - activations: Sequence[Union[str, Callable]] = ("relu",) + return_layernorm_output: bool = False + activations: Sequence[Union[str, Callable]] = ("gelu",) + activation_params: dict = None intermediate_dropout_rng_name: str = "dropout" - intermediate_dropout_rate: float = 0.1 + intermediate_dropout_rate: float = 0.0 intermediate_hidden_dropout_dims: Sequence[int] = () enable_low_rank_adaptation: bool = False low_rank_adaptation_dim: int = 32 @@ -969,6 +995,8 @@ class LayerNormMLP(TransformerEngineBase): dot_2_input_axes: Tuple[str, ...] = None ffn1_ckpt_name: str = "ffn1" ffn2_ckpt_name: str = "ffn2" + transpose_batch_sequence: bool = False + quantization_checkpoint_name: Optional[str] = None def __post_init__(self): if self.kernel_init is None: @@ -1003,8 +1031,12 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: """ assert self.axis == -1, "Only support axis == -1 at this moment" - ffn1_quantizer_set = self.generate_quantizer_set("_0") - ffn2_quantizer_set = self.generate_quantizer_set("_1") + ffn1_quantizer_set = self.generate_quantizer_set( + "_0", quantization_checkpoint_name=self.quantization_checkpoint_name + ) + ffn2_quantizer_set = self.generate_quantizer_set( + "_1", quantization_checkpoint_name=self.quantization_checkpoint_name + ) input_dtype = inputs.dtype ln_output = None @@ -1012,7 +1044,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: # TODO(Phuong): use fuse_layernorm for high-precision # when NoOpQuantizer and Tensor are implemented fuse_layernorm = ( - get_quantize_config().is_fp8_enabled() + ffn1_quantizer_set != noop_quantizer_set and not self.return_layernorm_output and self.enable_layernorm ) @@ -1023,6 +1055,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: ("relu", "linear"), ("quick_gelu", "linear"), ("squared_relu", "linear"), + ("clamped_silu", "clamped_linear"), ] act_pool = [("gelu",), ("silu",), 
("relu",), ("quick_gelu",), ("squared_relu",)] normalized_acts = [] @@ -1031,7 +1064,9 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: return False normalized_acts.append(act.lower()) normalized_acts = tuple( - reversed(normalized_acts) if normalized_acts[0] == "linear" else normalized_acts + reversed(normalized_acts) + if (normalized_acts[0] == "linear" or normalized_acts[0] == "clamped_linear") + else normalized_acts ) is_act_implemented = normalized_acts in (gated_act_pool + act_pool) @@ -1095,7 +1130,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + if ffn1_quantizer_set == noop_quantizer_set: kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] @@ -1107,7 +1142,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_2_shape, self.dtype, ) - if not get_quantize_config().is_fp8_enabled(): + if ffn2_quantizer_set == noop_quantizer_set: kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -1150,7 +1185,9 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name=self.ffn1_ckpt_name, ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, + activation_params=self.activation_params, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), + transpose_batch_sequence=self.transpose_batch_sequence, ) out = out.reshape(*inputs.shape[: self.axis], *hidden_size_tuple) @@ -1169,6 +1206,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) else: y = with_sharding_constraint_by_logical_axes(y, self.dot_1_input_axes) @@ -1179,6 +1217,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_1_input_axes, kernel_axes=self.kernel_axes_1, quantizer_set=ffn1_quantizer_set, + 
transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -1251,6 +1290,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): input_axes=self.dot_2_input_axes, kernel_axes=self.kernel_axes_2, quantizer_set=ffn2_quantizer_set, + transpose_batch_sequence=self.transpose_batch_sequence, ) if self.enable_low_rank_adaptation: @@ -1287,4 +1327,4 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype - return out, ln_output # Output, layner_norm_output + return out, ln_output # Output, layer_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index fb3ac7b9a..d096e7997 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -53,6 +53,7 @@ def _generate_drop_path_shape(shape: Sequence[int], batch_dim: int) -> Sequence[ return drop_path_shape +# TODO(Phuong): move this function to sharding.py def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: """ Extend the given Flax logical axis rules with the predefined TransformerLayer's @@ -65,7 +66,7 @@ def extend_logical_axis_rules(rules: LogicalRules) -> LogicalRules: for 1D-sharding tensor parallelism. .. warning:: - Please make sure ShardingResource is set via fp8_autocast before calling this function. + Please make sure ShardingResource is set via autocast before calling this function. .. note:: This function is only needed when using TransformerLayer. 
For other modules, such as @@ -196,6 +197,7 @@ def __call__( fused_scale_factor = scale_factor if self.attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: attn_weights += bias + bias = None def apply_swa_mask(original_mask: Array) -> Array: """Apply the sliding window mask to a given mask""" @@ -405,10 +407,10 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods Users can select between these two backends via the :attr:`NVTE_FUSED_ATTN` environment variable: - * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention (default). - * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention. If the required cuDNN fused attention - kernel is not available on the system, a warning will be issued, and the module will - automatically fall back to the unfused backend. + * Set :attr:`NVTE_FUSED_ATTN=0` for unfused attention. + * Set :attr:`NVTE_FUSED_ATTN=1` for fused attention (default). If the required cuDNN fused + attention kernel is not available on the system, a warning will be issued, and the module + will automatically fall back to the unfused backend. .. 
note:: The DotProductAttention default setting enables non-deterministic kernels for reduced @@ -600,7 +602,8 @@ def __call__( else: assert bias is not None - enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "0")) + # Use fused attn (if kernel check below passes) by default + enable_fused_attn = int(os.getenv("NVTE_FUSED_ATTN", "1")) sequence_dim = 0 if self.transpose_batch_sequence else 1 seqlen_q = query.shape[sequence_dim] @@ -1206,6 +1209,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="qkv", dtype=self.dtype, )(inputs_q) @@ -1233,6 +1237,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="query", )(inputs_q) @@ -1251,6 +1256,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): enable_low_rank_adaptation=lora_scope.qkv_proj, low_rank_adaptation_dim=self.low_rank_adaptation_dim, low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, + transpose_batch_sequence=self.transpose_batch_sequence, name="kv", dtype=self.dtype, )(inputs_kv) @@ -1291,6 +1297,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kernel_init=query_init, layernorm_input_axes=inputs_logical_axes_maybe_sp, dot_input_axes=inputs_logical_axes_no_sp, + transpose_batch_sequence=self.transpose_batch_sequence, name="query", )(inputs_q) @@ -1613,7 +1620,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods Dimensions that will share the same dropout mask for hidden attention_dropout: float, default = 0.1 Dropout probability for the dropout op during multi-head attention. 
- intermediate_dropout: float, default = 0.1 + intermediate_dropout: float, default = 0.0 Dropout probability for the dropout op after FC1 layer. intermediate_dropout_dims: Sequence[int], default = () Dimensions that will share the same dropout mask for hidden after FC1 layer. @@ -1628,9 +1635,12 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods flax.linen.initializers.variance_scaling(1.0, 'fan_in', 'truncated_normal') Used for initializing weights of FC1 and FC2 layers. It should be a callable object with three arguments (jax.random.PRNGKey, shape, dtype). - mlp_activations: Sequence[str], default = ('relu', ) + mlp_activations: Sequence[str], default = ('gelu', ) The sequence of activation functions to apply after the first linear transformation. Each activation has its own transformation layer. + mlp_activation_params: dict = None + This is only used when ('clamped_silu', 'clamped_linear') is in :attr:`mlp_activations`. At the moment + ClampedSwiglu is the only activation that requires parameters. use_bias: bool, default = False Indicate whether to enable bias shifting for QKVO projections, FC1 and FC2. If set to False, the layer will not learn additive biases. 
@@ -1745,12 +1755,13 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods hidden_dropout: float = 0.1 hidden_dropout_dims: Sequence[int] = () attention_dropout: float = 0.1 - intermediate_dropout: float = 0.1 + intermediate_dropout: float = 0.0 intermediate_dropout_dims: Sequence[int] = () dropout_rng_name: str = "dropout" mha_kernel_init: Initializer = None mlp_kernel_init: Initializer = None - mlp_activations: Sequence[str] = ("relu",) + mlp_activations: Sequence[str] = ("gelu",) + mlp_activation_params: dict = None use_bias: bool = False bias_init: Initializer = nn.initializers.zeros apply_residual_connection_post_layernorm: bool = False @@ -2045,6 +2056,7 @@ def hidden_dropout(x, deterministic): return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, + activation_params=self.mlp_activation_params, intermediate_dropout_rng_name=self.dropout_rng_name, intermediate_dropout_rate=self.intermediate_dropout, intermediate_hidden_dropout_dims=self.intermediate_dropout_dims, @@ -2064,6 +2076,7 @@ def hidden_dropout(x, deterministic): layernorm_input_axes=(*generate_batch_seqlen_logical_axes(), HIDDEN_AXES), dot_1_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_AXES), dot_2_input_axes=(*generate_batch_seqlen_logical_axes(False), HIDDEN_TP_AXES), + transpose_batch_sequence=self.transpose_batch_sequence, name="mlp", )(mlp_input, deterministic=deterministic) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index fb9783075..14726553f 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -16,13 +16,13 @@ import jax.numpy as jnp from . 
import cpp_extensions as tex +from .cpp_extensions.amax import AmaxScope from .quantize import ( QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, TensorUsage, - get_quantize_config, ) @@ -35,6 +35,7 @@ def layernorm_dense( norm_type: str = "layernorm", zero_centered_gamma: bool = False, epsilon: float = 1e-6, + transpose_batch_sequence: bool = False, layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, @@ -55,6 +56,7 @@ def layernorm_dense( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix @@ -70,7 +72,7 @@ def layernorm_dense( - Quantization is applied to both the normalized input and kernel """ - if not get_quantize_config().is_fp8_enabled(): + if quantizer_set == noop_quantizer_set: input_dtype = x.dtype kernel = kernel.astype(input_dtype) @@ -83,6 +85,7 @@ def layernorm_dense( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -100,6 +103,7 @@ def layernorm_dense( 8, 9, 10, + 11, ), ) def _layernorm_dense( @@ -111,6 +115,7 @@ def _layernorm_dense( norm_type: str, zero_centered_gamma: bool, epsilon: float, + transpose_batch_sequence: bool, layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], @@ -131,6 +136,7 @@ def _layernorm_dense( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + transpose_batch_sequence: Whether to transpose the batch 
and sequence dimensions layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding quantizer_set: Set of quantizers @@ -147,6 +153,7 @@ def _layernorm_dense( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -164,6 +171,7 @@ def _layernorm_dense_fwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, @@ -194,6 +202,8 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) @@ -203,6 +213,8 @@ def _layernorm_dense_fwd_rule( kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -213,6 +225,7 @@ def _layernorm_dense_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -222,8 +235,8 @@ def _layernorm_dense_fwd_rule( output += jnp.reshape(bias, bias_new_shape) ctx = ( - casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), - casted_kernel.get_tensor(TensorUsage.RHS_TRANS), + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(quantizer_set.x), + casted_kernel.get_tensor(TensorUsage.RHS_TRANS).checkpoint(quantizer_set.kernel), x.shape, kernel.shape, mu, @@ -245,6 +258,7 @@ def _layernorm_dense_bwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, layernorm_input_axes, dot_input_axes, kernel_axes, 
@@ -285,6 +299,8 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim @@ -301,6 +317,7 @@ def _layernorm_dense_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -314,6 +331,7 @@ def _layernorm_dense_bwd_rule( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_constracting_dim, g_constracting_dim), + transpose_batch_sequence=transpose_batch_sequence, ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index fc957801a..47fed6c3a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -21,13 +21,13 @@ from jax.ad_checkpoint import checkpoint_name from . import cpp_extensions as tex +from .cpp_extensions.amax import AmaxScope from .layernorm import canonicalize_norm_type from .quantize import ( with_sharding_constraint_by_logical_axes, QuantizerSet, noop_quantizer_set, TensorUsage, - get_quantize_config, ) @@ -40,6 +40,7 @@ def layernorm_mlp( norm_type: str, zero_centered_gamma: bool = False, epsilon: float = 1e-6, + transpose_batch_sequence: bool = False, norm_input_axes: Tuple[str, ...] = None, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] 
= None, @@ -48,6 +49,11 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), + activation_params: dict = None, + collective_op_sets: Tuple[tex.CollectiveOpSet] = ( + tex.noop_collective_op_set, + tex.noop_collective_op_set, + ), quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. @@ -71,6 +77,7 @@ def layernorm_mlp( norm_type: Type of normalization ("layernorm" or "rmsnorm") zero_centered_gamma: Whether to use zero-centered gamma for normalization epsilon: Small constant for numerical stability in normalization + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for sharding the layernorm input dot_1_input_axes: Logical axes for sharding the first matrix multiplication dot_2_input_axes: Logical axes for sharding the second matrix multiplication @@ -79,6 +86,7 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -105,7 +113,7 @@ def layernorm_mlp( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" - if not get_quantize_config().is_fp8_enabled(): + if quantizer_sets == (noop_quantizer_set, noop_quantizer_set): input_dtype = x.dtype kernel_1 = kernel_1.astype(input_dtype) kernel_2 = kernel_2.astype(input_dtype) @@ -121,6 +129,7 @@ def layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -129,12 +138,14 @@ 
def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, + collective_op_sets, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -146,6 +157,7 @@ def _layernorm_mlp( norm_type: str, zero_centered_gamma: bool, epsilon: float, + transpose_batch_sequence: bool, norm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], @@ -154,6 +166,8 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], + activation_params: dict, + collective_op_sets: Tuple[tex.CollectiveOpSet], quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -173,12 +187,16 @@ def _layernorm_mlp( norm_type: Type of normalization zero_centered_gamma: Whether to use zero-centered gamma epsilon: Small constant for numerical stability + transpose_batch_sequence: Whether to transpose the batch and sequence dimensions norm_input_axes: Logical axes for layernorm sharding dot_1_input_axes: Logical axes for first matrix multiplication sharding dot_2_input_axes: Logical axes for second matrix multiplication sharding + kernel_1_axes: Logical axes for first weight matrix sharding + kernel_2_axes: Logical axes for second weight matrix sharding ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) + collective_op_sets: Tuple of two collective gemm config sets for the two dense layer transformations quantizer_sets: Tuple of quantizer sets Returns: @@ -195,6 +213,7 @@ def _layernorm_mlp( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -203,6 +222,8 @@ 
def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, + collective_op_sets, quantizer_sets, ) return output @@ -219,6 +240,7 @@ def _layernorm_mlp_fwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -227,6 +249,8 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, + collective_op_sets, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -246,6 +270,10 @@ def _layernorm_mlp_fwd_rule( del kernel_1_axes, kernel_2_axes ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.forward.is_reduce_scatter + assert not collective_op_set_2.forward.is_all_gather # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (hidden_in, activation_len, intermediate) @@ -272,6 +300,8 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) @@ -279,6 +309,8 @@ def _layernorm_mlp_fwd_rule( kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) # NN GEMM @@ -287,8 +319,10 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_1.forward, ) if use_bias_1 and tex.gemm_uses_jax_dot(): @@ -310,6 +344,13 @@ def _layernorm_mlp_fwd_rule( dot_1_output, activation_type, 
quantizer=ffn2_quantizer_set.x, + act_params=( + tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) @@ -317,6 +358,8 @@ def _layernorm_mlp_fwd_rule( casted_kernel_2 = tex.quantize( kernel_2, quantizer=ffn2_quantizer_set.kernel, + amax_scope=AmaxScope.FSDP, + transpose_batch_sequence=transpose_batch_sequence, ) # NN GEMM @@ -325,8 +368,10 @@ def _layernorm_mlp_fwd_rule( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + collective_op=collective_op_set_2.forward, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -334,6 +379,8 @@ def _layernorm_mlp_fwd_rule( bias_2_new_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2_shape dot_2_output += jnp.reshape(bias_2, bias_2_new_shape) + # sharding of outputs should be the same as dot_1's input + dot_2_output = with_sharding_constraint_by_logical_axes(dot_2_output, dot_1_input_axes) dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) ctx = ( @@ -342,11 +389,11 @@ def _layernorm_mlp_fwd_rule( rsigma, gamma, beta, - casted_ln_out.get_tensor(TensorUsage.LHS_TRANS), - casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS), + casted_ln_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn1_quantizer_set.x), + casted_kernel_1.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn1_quantizer_set.kernel), dot_1_output, - casted_act_out.get_tensor(TensorUsage.LHS_TRANS), - casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS), + casted_act_out.get_tensor(TensorUsage.LHS_TRANS).checkpoint(ffn2_quantizer_set.x), + 
casted_kernel_2.get_tensor(TensorUsage.RHS_TRANS).checkpoint(ffn2_quantizer_set.kernel), x_contracting_dims, k_contracting_dims, kernel_1.shape, @@ -363,6 +410,7 @@ def _layernorm_mlp_bwd_rule( norm_type, zero_centered_gamma, epsilon, + transpose_batch_sequence, norm_input_axes, dot_1_input_axes, dot_2_input_axes, @@ -371,6 +419,8 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, + activation_params, + collective_op_sets, ctx, grad, ): @@ -409,6 +459,10 @@ def _layernorm_mlp_bwd_rule( ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets + collective_op_set_1, collective_op_set_2 = collective_op_sets + + assert not collective_op_set_1.backward.is_all_gather + assert not collective_op_set_2.backward.is_reduce_scatter # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) @@ -417,6 +471,8 @@ def _layernorm_mlp_bwd_rule( grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -434,6 +490,8 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op_set_2.backward, ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -448,6 +506,7 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -457,6 +516,13 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, + act_params=( + 
tex.activation.ActivationParams.create(activation_type, **activation_params) + if activation_params + else None + ), + amax_scope=AmaxScope.TPSP, + transpose_batch_sequence=transpose_batch_sequence, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -474,6 +540,8 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), + transpose_batch_sequence=transpose_batch_sequence, + collective_op=collective_op_set_1.backward, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -484,6 +552,7 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), + transpose_batch_sequence=transpose_batch_sequence, ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/quantize/__init__.py b/transformer_engine/jax/quantize/__init__.py index 11f692917..878067a78 100644 --- a/transformer_engine/jax/quantize/__init__.py +++ b/transformer_engine/jax/quantize/__init__.py @@ -14,5 +14,7 @@ from .dequantizer import * from .scaling_modes import * from .metadata import * +from .hadamard import * from .helper import * from .device_utils import * +from .misc import * diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 9d46c3c30..80ebc6b87 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -15,6 +15,8 @@ import jax.numpy as jnp from .scaling_modes import ScalingMode +from .hadamard import apply_rht + __all__ = ["ScalingModeToDequantizerMap"] @@ -119,7 +121,7 @@ def _dequantize_func(data, scale_inv, dq_dtype, scaling_mode, is_colwise, flatte 0 < flatten_axis < len(data_shape) ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" 
class NVFP4Dequantizer(Dequantizer):
    """NVFP4 Dequantizer Class.

    This class provides static methods for dequantizing tensors that have been
    quantized using NVFP4 scaling modes.
    """

    @staticmethod
    def _dequantize_func(
        data, scale_inv, amax, dq_dtype, scaling_mode, is_colwise, flatten_axis, has_rht_applied
    ):
        """Dequantize a tensor using block scaling.

        Args:
            data: The quantized tensor data
            scale_inv: The inverse scaling factors (per-block)
            amax: The maximum absolute value of the tensor
            dq_dtype: The data type for dequantized values
            scaling_mode: The scaling mode used for quantization
            is_colwise: Whether the scaling is column-wise
            flatten_axis: The axis along which the tensor could be flattened to 2D
            has_rht_applied: Whether the quantization has RHT applied and we need to apply the inverse RHT to dequantize

        Returns:
            The dequantized tensor
        """

        # NVFP4 uses two-level scaling: per-block factors (scale_inv) plus a
        # single per-tensor factor derived from amax and the dtype ranges.
        DATA_DTYPE_MAX = jnp.finfo(data.dtype).max.astype(jnp.float32)
        SCALE_DTYPE_MAX = jnp.finfo(scale_inv.dtype).max.astype(jnp.float32)
        tensor_scale_inv = amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX)

        # Fold the per-tensor scale into the per-block scales up front.
        data = data.astype(jnp.float32)
        scale_inv = scale_inv.astype(jnp.float32) * tensor_scale_inv
        data_layout = "T" if is_colwise else "N"

        data_shape = data.shape
        # Normalize a negative flatten_axis to its positive equivalent.
        flatten_axis = len(data_shape) + flatten_axis if flatten_axis < 0 else flatten_axis
        assert (
            0 < flatten_axis < len(data_shape)
        ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}"
        scale_shape = scaling_mode.get_scale_shape(
            data_shape,
            data_layout=data_layout,
            is_colwise=is_colwise,
            is_padded=False,
            # expect the flatten_axis wrt the N layout
            flatten_axis=flatten_axis if data_layout == "N" else len(data_shape) - flatten_axis,
            broadcast_2d_scale_shape_to_1d=True,
        )

        # Split the two scaled dimensions into (num_blocks, block_size) pairs
        # so the per-block scales can broadcast against the data.
        data = data.reshape(
            *data_shape[: flatten_axis - 1],
            scale_shape[flatten_axis - 1],
            int(data_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]),
            *data_shape[flatten_axis:-1],
            scale_shape[-1],
            int(data_shape[-1] / scale_shape[-1]),
        )

        # Insert block-size axes into scale_inv so it lines up with the
        # reshaped data for broadcasting.
        # NOTE(review): `flatten_axis + 2 - 2` is just `flatten_axis`; presumably
        # written this way to mirror a derivation — confirm before simplifying.
        scale_inv = jnp.expand_dims(scale_inv, axis=(flatten_axis + 2 - 2, -1))
        out = jnp.asarray(data * scale_inv, dq_dtype).reshape(data_shape)

        # Apply inverse of RHT if needed
        if has_rht_applied:
            out = apply_rht(out, inverse=True)

        return out

    @staticmethod
    def dequantize(scaled_tensor):
        """Dequantize a tensor using block scaling.

        Args:
            scaled_tensor: The quantized tensor to dequantize

        Returns:
            The dequantized tensor
        """
        return NVFP4Dequantizer._dequantize_func(
            scaled_tensor.data,
            scaled_tensor.scale_inv,
            scaled_tensor.amax,
            scaled_tensor.dq_dtype,
            scaled_tensor.scaling_mode,
            scaled_tensor.is_colwise,
            scaled_tensor.flatten_axis,
            scaled_tensor.has_rht_applied,
        )


# Dispatch table mapping each scaling mode to the dequantizer that handles it.
ScalingModeToDequantizerMap = {
    ScalingMode.DELAYED_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.CURRENT_TENSOR_SCALING: TensorScaleDequantizer,
    ScalingMode.MXFP8_1D_SCALING: BlockScaleDequantizer,
    ScalingMode.NVFP4_1D_SCALING: NVFP4Dequantizer,
    ScalingMode.NVFP4_2D_SCALING: NVFP4Dequantizer,
    ScalingMode.NO_SCALING: NoopDequantizer,
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Randomized Hadamard Transform (RHT) utilities for JAX."""
import jax.numpy as jnp


def get_wgrad_sign_vector() -> list[int]:
    """Get a fixed sign vector for the RHT used in NVFP4 weight gradient quantization."""
    return [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1]


def get_sign_from_vector(vector: list[int]) -> int:
    """Convert a sign vector to a bitmask integer.

    Bit ``i`` of the returned mask is set iff ``vector[i] == -1``.
    """
    mask = 0
    for i, v in enumerate(vector):
        mask |= (v == -1) << i
    return mask


def apply_rht(x: jnp.ndarray, inverse=False) -> jnp.ndarray:
    """Apply the Randomized Hadamard Transform (RHT) to the input tensor.

    Args:
        x: Input tensor; its flattened size must be a multiple of 16.
        inverse: If True, apply the inverse transform instead.

    Returns:
        A tensor of the same shape as ``x`` with the (inverse) RHT applied to
        each contiguous block of 16 elements.
    """
    h = get_rht_matrix()
    block_size = 16
    if inverse:
        # The RHT matrix is orthogonal (signed Hadamard scaled by 1/sqrt(16)),
        # so its inverse is exactly its transpose. Using the transpose avoids
        # a numerical matrix inversion and the float32/bfloat16 round-trip
        # error the previous jnp.linalg.inv-based path introduced.
        h = h.T
    # TODO(jberchtold): These reshapes will break partitioning, fixme
    return (x.reshape(-1, block_size) @ h).reshape(x.shape)


def get_rht_matrix() -> jnp.ndarray:
    """Get the Randomized Hadamard Transform (RHT) matrix used in NVFP4 weight gradient quantization.

    Returns:
        A (16, 16) bfloat16 matrix representing the RHT. This matrix is pre-multiplied by the random sign mask.
    """
    # Imported lazily (and explicitly as the submodule) so scipy is only
    # required when RHT is actually used; `import scipy` alone does not
    # guarantee `scipy.linalg` is loaded.
    import scipy.linalg

    block_size = 16
    h = jnp.array(scipy.linalg.hadamard(block_size))

    # Apply the random sign mask
    s = jnp.array(get_wgrad_sign_vector(), dtype=jnp.int32)
    h = jnp.diag(s) @ h

    # Normalize so the transform is orthonormal; entries +/-0.25 are exactly
    # representable in bfloat16.
    return (h / jnp.sqrt(block_size)).astype(jnp.bfloat16)
+# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -9,13 +9,18 @@ This module provides configuration and helper functions for managing quantization metadata in JAX, including support for different scaling modes and datatypes. """ + from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Dict, Union, Sequence, Type -from functools import reduce +import hashlib +from typing import Optional, Tuple, Dict, Union, Sequence, Type, List +from functools import reduce, lru_cache import operator +from importlib.metadata import version as get_pkg_version +import warnings +from packaging.version import Version as PkgVersion import jax import jax.numpy as jnp @@ -29,32 +34,59 @@ get_cublasLt_version, get_cuda_version, ) -from transformer_engine.common import recipe -from transformer_engine.jax.sharding import global_shard_guard, MeshResource - +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, + NVFP4BlockScaling, +) +from transformer_engine.jax.sharding import ( + global_shard_guard, + MeshResource, + get_num_devices_in_mesh, + get_all_mesh_axes, + with_sharding_constraint, +) + +from .metadata import QuantizeMeta from .scaling_modes import ScalingMode -from .. 
import cpp_extensions as tex from .device_utils import get_device_compute_capability __all__ = [ - "get_quantize_config", + "get_global_quantize_recipe", + "get_quantize_config_with_recipe", + "autocast", "fp8_autocast", "is_fp8_available", + "is_scaling_mode_supported", + "get_supported_scaling_modes", + "get_supported_quantization_recipes", "update_collections", - "get_delayed_scaling", "apply_padding_to_scale_inv", "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", "TensorSource", ] -_is_fp8_available = None -_reason_for_no_fp8 = "" +_is_scaling_mode_supported = None +_reason_for_no_scaling_mode = "" Collection = Union[Dict, FrozenDict] NVTE_FP8_COLLECTION_NAME = "fp8_metas" +@lru_cache(maxsize=None) +def _jax_version_meet_requirement(version: str): + """ + Helper function checking if required JAX version is available + """ + jax_version = PkgVersion(get_pkg_version("jax")) + jax_version_required = PkgVersion(version) + return jax_version >= jax_version_required + + def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -69,8 +101,6 @@ def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: return True, "" else: return False, "Device arch gfx94x or gfx95x required for FP8 execution." - if gpu_arch >= 90: # hopper and above - return True, "" if gpu_arch < 89: # pre-ada return False, "Device compute capability 8.9 or higher required for FP8 execution." if get_cublasLt_version() < 120103: @@ -91,20 +121,33 @@ def _check_block_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """ if is_hip_extension(): return False, "FP8 block scaled gemm not yet supported for ROCm" - if gpu_arch >= 100: # blackwell and above - return True, "" if gpu_arch < 99: # pre-blackwell return False, "Device compute capability 9.9 or higher required for MXFP8 execution." 
if get_cublasLt_version() < 120800: return False, "CublasLt version 12.8.0 or higher required for MXFP8 execution." - if get_cuda_version() < 12010: + if get_cuda_version() < 12080: return False, "Cuda version 12.8 or higher required for MXFP8 execution." - if not tex.jax_version_meet_requirement("0.5.3"): + if not _jax_version_meet_requirement("0.5.3"): return False, "Jax version 0.5.3 or higher required for MXFP8 execution." return True, "" -def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: +def _check_fp4_support(gpu_arch) -> Tuple[bool, str]: + """Check if FP4 is supported for the given GPU architecture.""" + if is_hip_extension(): + return False, "FP4 not yet supported for ROCm" + if gpu_arch < 100: # pre-blackwell + return False, "Device compute capability 10.0 or higher required for NVFP4 execution." + if get_cublasLt_version() < 120800: + return False, "CublasLt version 12.8.0 or higher required for NVFP4 execution." + if get_cuda_version() < 12080: + return False, "Cuda version 12.8 or higher required for NVFP4 execution." + if not _jax_version_meet_requirement("0.5.3"): + return False, "Jax version 0.5.3 or higher required for NVFP4 execution." + return True, "" + + +def _check_scaling_support(scaling_mode: ScalingMode, gpu_id: int) -> Tuple[bool, str]: """Check if FP8 is supported for the given scaling mode and GPU. 
Args: @@ -117,9 +160,35 @@ def _check_fp8_support(scaling_mode, gpu_id) -> Tuple[bool, str]: gpu_arch = get_device_compute_capability(gpu_id) if scaling_mode.is_tensor_scaling(): return _check_delayed_scaling_fp8_support(gpu_arch) - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + if scaling_mode.is_mxfp8_scaling: return _check_block_scaling_fp8_support(gpu_arch) - return (False, "Unsupported scaling_mode!") + if scaling_mode.is_nvfp4_scaling: + return _check_fp4_support(gpu_arch) + return (True, "") # NO_SCALING is always supported + + +def is_scaling_mode_supported( + scaling_mode=ScalingMode.NO_SCALING, + gpu_id=None, +) -> Tuple[bool, str]: + """Check if the given scaling mode is available for the given GPU.""" + if gpu_id is not None: + return _check_scaling_support(scaling_mode, gpu_id) + + global _is_scaling_mode_supported, _reason_for_no_scaling_mode + if _is_scaling_mode_supported is None: + _is_scaling_mode_supported = {} + _reason_for_no_scaling_mode = {} + if scaling_mode not in _is_scaling_mode_supported: + _is_scaling_mode_supported[scaling_mode] = True + _reason_for_no_scaling_mode[scaling_mode] = "" + for local_gpu_id in range(len(jax.local_devices())): + ret, msg = _check_scaling_support(scaling_mode, local_gpu_id) + if ret is False: + _is_scaling_mode_supported[scaling_mode] = ret + _reason_for_no_scaling_mode[scaling_mode] = msg + return ret, msg + return _is_scaling_mode_supported[scaling_mode], _reason_for_no_scaling_mode[scaling_mode] def is_fp8_available( @@ -135,29 +204,39 @@ def is_fp8_available( Returns: A tuple of (bool, str) indicating availability and any error message """ - if gpu_id is not None: - return _check_fp8_support(scaling_mode, gpu_id) - - global _is_fp8_available, _reason_for_no_fp8 - if _is_fp8_available is None: - _is_fp8_available = {} - _reason_for_no_fp8 = {} - - if scaling_mode not in _is_fp8_available: - _is_fp8_available[scaling_mode] = True - _reason_for_no_fp8[scaling_mode] = "" - # JAX doesn't provide the local 
GPU id. - for local_gpu_id in range(len(jax.local_devices())): - ret, msg = _check_fp8_support(scaling_mode, local_gpu_id) - if ret is False: - _is_fp8_available[scaling_mode] = ret - _reason_for_no_fp8[scaling_mode] = msg - return ret, msg - - return _is_fp8_available[scaling_mode], _reason_for_no_fp8[scaling_mode] - - -def _format2dtypes(format_: recipe.Format): + warnings.warn( + "is_fp8_available is deprecated. Use is_scaling_mode_supported instead.", DeprecationWarning + ) + return is_scaling_mode_supported(scaling_mode=scaling_mode, gpu_id=gpu_id) + + +# TODO(Phuong): make the infrastruture to support NO_SCALING +def get_supported_scaling_modes() -> List[ScalingMode]: + """Get all supported quantization scaling modes.""" + return [ + scaling_mode + for scaling_mode in ScalingMode + if is_scaling_mode_supported(scaling_mode=scaling_mode)[0] + and scaling_mode != ScalingMode.NO_SCALING + ] + + +def get_supported_quantization_recipes() -> List[Recipe]: + """Get all supported quantization recipes.""" + # We don't support all the recipes TE/Common supports yet + # return [get_quantize_config_class(recipe)() for recipe in recipe.Recipe.__subclasses__()] + all_recipes = [ + DelayedScaling(), + Float8CurrentScaling(), + MXFP8BlockScaling(), + NVFP4BlockScaling(), + ] + return [ + recipe for recipe in all_recipes if get_quantize_config_class(recipe)().is_supported()[0] + ] + + +def _format2dtypes(format_: Format): """Convert recipe.Format.dtype to corresponding JAX dtypes. 
Args: @@ -166,12 +245,14 @@ def _format2dtypes(format_: recipe.Format): Returns: A tuple of (forward_dtype, backward_dtype) for the given format """ - if format_ == recipe.Format.E4M3: + if format_ == Format.E4M3: return get_jnp_float8_e4m3_type(), get_jnp_float8_e4m3_type() - if format_ == recipe.Format.E5M2: + if format_ == Format.E5M2: return get_jnp_float8_e5m2_type(), get_jnp_float8_e5m2_type() - if format_ == recipe.Format.HYBRID: + if format_ == Format.HYBRID: return get_jnp_float8_e4m3_type(), get_jnp_float8_e5m2_type() + if format_ == Format.E2M1: + return jnp.float4_e2m1fn, jnp.float4_e2m1fn return jnp.bfloat16, jnp.bfloat16 @@ -209,7 +290,6 @@ class BaseQuantizeConfig(ABC): INITIALIZED: Whether the config has been initialized MARGIN: Margin value for quantization COLLECTION_NAME: Name of the collection for quantization metadata - FP8_FORMAT: FP8 format to use FWD_DTYPE: Forward pass data type BWD_DTYPE: Backward pass data type FP8_2X_ACC_FPROP: Whether to use 2x accumulation for forward pass @@ -223,28 +303,26 @@ class BaseQuantizeConfig(ABC): INITIALIZED = False MARGIN: float = 0.0 COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME - FP8_FORMAT: recipe.Format = recipe.Format.HYBRID - FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] - BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] + FWD_DTYPE: DType = None + BWD_DTYPE: DType = None FP8_2X_ACC_FPROP: bool = False FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False # DelayedScaling + # TODO(Phuong): move these two into DelayedScalingQuantizeConfig AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: - """Initialize the quantization configuration. + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: + """Initialize the quantization configuration from a given recipe. 
Args: fp8_recipe: The FP8 recipe to use for initialization """ self.INITIALIZED = True - self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 - self.FP8_FORMAT = fp8_recipe.fp8_format - self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT) + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp8_format) def is_fp8_enabled(self) -> bool: """Check if FP8 quantization is enabled. @@ -265,6 +343,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: The scaling mode for the specified usage type. """ + @abstractmethod + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + def is_supported(self) -> tuple[bool, str]: """Check if this QuantizeConfig class is supported on the available devices. 
@@ -277,7 +376,7 @@ def is_supported(self) -> tuple[bool, str]: kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: - is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) + is_supported, reason = is_scaling_mode_supported(scaling_mode=scaling_mode) if not is_supported: return is_supported, reason return True, None @@ -286,7 +385,7 @@ def is_supported(self) -> tuple[bool, str]: class NoOpQuantizeConfig(BaseQuantizeConfig): """Configuration class higher-precision non-quantized operation.""" - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize no-op configuration.""" raise NotImplementedError( "NoOpQuantizeConfig cannot be initialize from a recipe as it represents" @@ -297,6 +396,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.NO_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + class DelayedScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for delayed scaling FP8 recipe. 
@@ -305,7 +425,7 @@ class DelayedScalingQuantizeConfig(BaseQuantizeConfig): FP8 quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize delayed scaling FP8 configuration. Args: @@ -315,6 +435,7 @@ def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: AssertionError: If recipe parameters are not supported """ super().initialize_from_recipe(fp8_recipe) + self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 assert fp8_recipe.amax_compute_algo in [ "max", @@ -339,6 +460,46 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.DELAYED_TENSOR_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + scale = module.variable( + collection_name, + f"{quantizer_name}{postfix}_scale", + jnp.ones, + (1,), + jnp.float32, + ).value + amax_history = module.variable( + collection_name, + f"{quantizer_name}{postfix}_amax_history", + jnp.zeros, + (self.AMAX_HISTORY_LEN,), + jnp.float32, + ).value + return QuantizeMeta( + margin=self.MARGIN, + amax_compute_algo=self.AMAX_COMPUTE_ALGO, + scale=scale, + amax_history=amax_history, + ) + class CurrentScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for current scaling FP8 recipe. 
@@ -347,7 +508,7 @@ class CurrentScalingQuantizeConfig(BaseQuantizeConfig): FP8 quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize current scaling FP8 configuration. Args: @@ -360,6 +521,27 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.CURRENT_TENSOR_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + class BlockScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for block scaling FP8 recipe. @@ -368,7 +550,7 @@ class BlockScalingQuantizeConfig(BaseQuantizeConfig): FP8 quantization mode. """ - def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: """Initialize block scaling FP8 configuration. Args: @@ -381,59 +563,203 @@ def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: """Gets the scaling mode for a specific tensor's usage type.""" return ScalingMode.MXFP8_1D_SCALING + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. 
+ + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. It can be empty if no metadata is needed. + """ + return QuantizeMeta() + + +@dataclass +class NVFP4ScalingQuantizeConfig(BaseQuantizeConfig): + """Configuration class for NVFP4 scaling recipe. + + This class provides specific initialization and finalization for NVFP4 scaling quantization mode. + """ + + DISABLE_STOCHASTIC_ROUNDING: bool = False + DISABLE_RHT: bool = False + DISABLE_2D_QUANTIZATION: bool = False + + def initialize_from_recipe(self, fp8_recipe: Recipe) -> None: + """Initialize block scaling NVFP4 configuration. -_QUANTIZE_CONFIG = NoOpQuantizeConfig() + Args: + fp8_recipe: The quantization recipe to use for initialization + """ + assert isinstance(fp8_recipe, NVFP4BlockScaling) + self.INITIALIZED = True + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(fp8_recipe.fp4_format) + self.AMAX_HISTORY_LEN = 0 + + self.DISABLE_STOCHASTIC_ROUNDING = fp8_recipe.disable_stochastic_rounding + self.DISABLE_RHT = fp8_recipe.disable_rht + self.DISABLE_2D_QUANTIZATION = fp8_recipe.disable_2d_quantization + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + if (not self.DISABLE_2D_QUANTIZATION) and tensor_source == TensorSource.KERNEL: + return ScalingMode.NVFP4_2D_SCALING + # for x and grad + return ScalingMode.NVFP4_1D_SCALING + + def _make_rht_quantize_meta(self, q_layout, tensor_source: TensorSource) -> QuantizeMeta: + """Create the quantization metadata for RHT if applicable.""" + # Imported here to prevent circular import + from transformer_engine.jax.quantize import QuantizeLayout + + use_rht = 
self.get_scaling_mode( + tensor_source + ) == ScalingMode.NVFP4_1D_SCALING and q_layout in { + QuantizeLayout.ROWWISE_COLWISE, + QuantizeLayout.COLWISE, + } + if self.DISABLE_RHT: + use_rht = False + return QuantizeMeta(use_rht=use_rht) + + def _make_stochastic_rounding_rng_state( + self, module, tensor_source: TensorSource, quantizer_name: str + ) -> jnp.ndarray: + """Create the stochastic rounding rng state if applicable.""" + if self.DISABLE_STOCHASTIC_ROUNDING: + return QuantizeMeta() + + if tensor_source != TensorSource.DGRAD: + # Only DGRAD uses stochastic rounding + return QuantizeMeta() + + sr_jax_rng = module.make_rng("sr_rng") + # Get a unique key for this quantizer + # Use hashlib to get a deterministic hash value for quantizer_name + quantizer_hash = ( + int(hashlib.sha256(quantizer_name.encode("utf-8")).hexdigest(), 16) + % jnp.iinfo(jnp.int32).max + ) + sr_jax_rng = jax.jit(jax.random.fold_in)(sr_jax_rng, quantizer_hash) + + # Generate 4 random uint32 values per device from the JAX PRNG key + shape = (get_num_devices_in_mesh(), 4) + sr_jax_rng_state = jax.random.randint( + sr_jax_rng, shape, 0, jnp.iinfo(jnp.int32).max, dtype=jnp.int32 + ).view(jnp.uint32) + sr_jax_rng_state = with_sharding_constraint( + sr_jax_rng_state, jax.sharding.PartitionSpec(get_all_mesh_axes(), None) + ) + return QuantizeMeta(stochastic_rounding_rng_state=sr_jax_rng_state) + + def get_quantize_flax_meta( + self, + module, + collection_name: str, + postfix: str, + tensor_source: TensorSource, + quantizer_name: str, + ) -> QuantizeMeta: + """Get the quantization metadata for a given Flax module. + + Args: + module: The Flax module to get metadata for + collection_name: The name of the collection to store metadata in + postfix: Postfix to append to metadata names + tensor_source: The source type of the tensor (e.g., X, KERNEL, DGRAD) + quantizer_name: The name of the quantizer within the module + Returns: + The quantization metadata for the specified module and tensor. 
It can be empty if no metadata is needed. + """ + # Imported here to prevent circular import + from transformer_engine.jax.quantize import QuantizeLayout -def get_quantize_config(): - """Global instance of BaseQuantizeConfig set by fp8_autocast context.""" - return _QUANTIZE_CONFIG + return QuantizeMeta.merge( + self._make_rht_quantize_meta(QuantizeLayout.ROWWISE_COLWISE, tensor_source), + self._make_stochastic_rounding_rng_state(module, tensor_source, quantizer_name), + ) def get_quantize_config_class( - fp8_recipe: recipe.Recipe, + fp8_recipe: Recipe, ) -> Type[BaseQuantizeConfig]: - """Get the quantization configuration based on the FP8 recipe. + """Get the quantization configuration class based on the FP8 recipe. Args: fp8_recipe: The FP8 recipe to use for initialization Returns: The quantization config class corresponding to the given recipe. """ - if isinstance(fp8_recipe, recipe.DelayedScaling): + if fp8_recipe is None: + return NoOpQuantizeConfig + if isinstance(fp8_recipe, DelayedScaling): return DelayedScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + if isinstance(fp8_recipe, MXFP8BlockScaling): return BlockScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): + if isinstance(fp8_recipe, Float8CurrentScaling): return CurrentScalingQuantizeConfig + if isinstance(fp8_recipe, NVFP4BlockScaling): + return NVFP4ScalingQuantizeConfig raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") +def get_quantize_config_with_recipe(fp8_recipe: Recipe): + """Get the quantization configuration object based on the FP8 recipe.""" + config = get_quantize_config_class(fp8_recipe)() + if fp8_recipe is not None: + config.initialize_from_recipe(fp8_recipe) + return config + + +_GLOBAL_RECIPE: Optional[Recipe] = None + + +def get_global_quantize_recipe() -> Optional[Recipe]: + """Get the global quantization recipe if set. + + Returns: + The global quantization recipe or None if not set. 
+ """ + return _GLOBAL_RECIPE + + @contextmanager -def fp8_autocast( +def autocast( enabled: bool = False, - fp8_recipe: Optional[recipe.Recipe] = None, + recipe: Optional[Recipe] = None, mesh_resource: Optional[MeshResource] = None, ) -> None: - r"""Context manager for FP8 automatic mixed precision. + r"""Context manager for FP8 or FP4 usage. - This context manager enables FP8 quantization for the duration of its context. + This context manager enables quantization for the duration of its context. .. code-block:: python - mesh_shape = (4, 2) - dp_mesh_axis_name = 'data_parallel' - tp_mesh_axis_name = 'tensor_parallel' - devices = np.asarray(jax.devices()).reshape(*mesh_shape) + mesh_shape = (4, 2) + dp_mesh_axis_name = 'data_parallel' + tp_mesh_axis_name = 'tensor_parallel' + devices = np.asarray(jax.devices()).reshape(*mesh_shape) - with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): - mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) + with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): + mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) - with fp8_autocast(enabled=True, mesh_resource=mesh_resource): - rules = extend_logical_axis_rules(tuple()) - transformer = TransformerLayer() + with autocast(enabled=True, mesh_resource=mesh_resource): + rules = extend_logical_axis_rules(tuple()) + transformer = TransformerLayer() - with partitioning.axis_rules(rules): - pjit(transformer.init, ...)(...) + with partitioning.axis_rules(rules): + pjit(transformer.init, ...)(...) .. note:: We only support :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len`, @@ -445,58 +771,62 @@ def fp8_autocast( ---------- enabled: bool, default = False Whether or not to enable fp8 - fp8_recipe: recipe.DelayedScaling, default = None - Recipe used for FP8 training. + recipe: recipe.DelayedScaling, default = None + recipe used for low precision quantization. 
mesh_resource: MeshResource, default = None Specify the mesh axes for data and tensor parallelism to shard along. If set to None, then no data or tensor parallelism will be used. """ - if fp8_recipe is None: - fp8_recipe = recipe.DelayedScaling() + if recipe is None: + recipe = DelayedScaling() - global _QUANTIZE_CONFIG + global _GLOBAL_RECIPE - old_quantize_config = _QUANTIZE_CONFIG + old_global_recipe = _GLOBAL_RECIPE - _QUANTIZE_CONFIG = NoOpQuantizeConfig() + _GLOBAL_RECIPE = None try: with global_shard_guard(mesh_resource): if enabled: - _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)() - is_supported, reason = _QUANTIZE_CONFIG.is_supported() + _GLOBAL_RECIPE = recipe + is_supported, reason = get_quantize_config_class(_GLOBAL_RECIPE)().is_supported() assert is_supported, reason - _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe) yield finally: - _QUANTIZE_CONFIG = old_quantize_config + _GLOBAL_RECIPE = old_global_recipe -def get_delayed_scaling(): - r""" - Obtain an instance of DelayedScaling which is set via fp8_autocast. +@contextmanager +def fp8_autocast( + enabled: bool = False, + fp8_recipe: Optional[Recipe] = None, + mesh_resource: Optional[MeshResource] = None, +) -> None: + """ + .. warning:: + + fp8_autocast is deprecated and will be removed in a future release. + Use autocast(enabled=..., recipe=..., mesh_resource=...) instead. - .. note:: - We only store :attr:`margin`, :attr:`fp8_format`, :attr:`amax_history_len` - , and :attr:`amax_compute_algo` via fp8_autocast. Other parameters in - recipe.DelayedScaling would be returned as the default values. - - Returns - ------- - delay_scaling : DelayedScaling - an instance of DelayedScaling which is set via fp8_autocast. 
""" - amax_compute_algo = ( - "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" - ) - return recipe.DelayedScaling( - margin=int(get_quantize_config().MARGIN), - fp8_format=get_quantize_config().FP8_FORMAT, - amax_history_len=get_quantize_config().AMAX_HISTORY_LEN, - amax_compute_algo=amax_compute_algo, + + warnings.warn( + "fp8_autocast is deprecated and will be removed in a future release. " + "Use autocast(enabled=..., recipe=..., mesh_resource=...) instead.", + category=DeprecationWarning, + stacklevel=2, ) + # Call new implementation. + with autocast( + enabled=enabled, + recipe=fp8_recipe, + mesh_resource=mesh_resource, + ): + yield + def update_collections(new: Collection, original: Collection) -> Collection: r"""Update collections with new values while preserving original structure. diff --git a/transformer_engine/jax/quantize/metadata.py b/transformer_engine/jax/quantize/metadata.py index 637450216..a987643eb 100644 --- a/transformer_engine/jax/quantize/metadata.py +++ b/transformer_engine/jax/quantize/metadata.py @@ -9,23 +9,49 @@ scale factors and amax history for different tensor types. """ from dataclasses import dataclass -import jax.numpy as jnp __all__ = ["QuantizeMeta", "QuantizeMetaSet"] -@dataclass class QuantizeMeta: """Metadata for quantization parameters. - Attributes: + For Delayed Scaling recipe: scale: The scaling factor for quantization amax_history: History of maximum absolute values + + For NVFP4 recipe with Stochastic Rounding: + sr_rng_state: The state of the stochastic rounding RNG + """ - scale: jnp.ndarray - amax_history: jnp.ndarray + @staticmethod + def merge(a: "QuantizeMeta", b: "QuantizeMeta") -> "QuantizeMeta": + """Merge two QuantizeMeta instances. + + Args: + a (QuantizeMeta): The first QuantizeMeta instance. + b (QuantizeMeta): The second QuantizeMeta instance. + + Returns: + QuantizeMeta: A new QuantizeMeta instance with merged metadata. 
+ """ + assert isinstance(a, QuantizeMeta) + assert isinstance(b, QuantizeMeta) + for key in b.get_kwargs_dictionary().keys(): + if key in a.get_kwargs_dictionary(): + assert ( + a.get_kwargs_dictionary()[key] == b.get_kwargs_dictionary()[key] + ), f"Conflict in merging QuantizeMeta: {key} has different values." + return QuantizeMeta(**{**a.get_kwargs_dictionary(), **b.get_kwargs_dictionary()}) + + def __init__(self, **kwargs): + self._kwargs = kwargs + + def get_kwargs_dictionary(self): + """Get the metadata as a dictionary.""" + return self._kwargs @dataclass diff --git a/transformer_engine/jax/quantize/misc.py b/transformer_engine/jax/quantize/misc.py new file mode 100644 index 000000000..c1e169d00 --- /dev/null +++ b/transformer_engine/jax/quantize/misc.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +""" +This module provides additional enum and utilities for quantizing tensors in JAX. +""" +from dataclasses import dataclass +from enum import Enum + +from transformer_engine_jax import JAXX_Quantize_Layout + +__all__ = [ + "QuantizeLayout", +] + + +@dataclass(frozen=True) +class QuantizeLayout(Enum): + "Wrapper for JAXX_Quantize_Layout" + + ROWWISE = JAXX_Quantize_Layout.ROWWISE + COLWISE = JAXX_Quantize_Layout.COLWISE + ROWWISE_COLWISE = JAXX_Quantize_Layout.ROWWISE_COLWISE + + @property + def has_rowwise(self) -> bool: + """If the layout has the rowwise component""" + return self.value in (JAXX_Quantize_Layout.ROWWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE) + + @property + def has_colwise(self) -> bool: + """If the layout has the colwise component""" + return self.value in (JAXX_Quantize_Layout.COLWISE, JAXX_Quantize_Layout.ROWWISE_COLWISE) + + @property + def is_rowwise_colwise(self) -> bool: + """If layout is both rowwise and colwise""" + return self.value == JAXX_Quantize_Layout.ROWWISE_COLWISE + + @property + def is_rowwise_only(self) -> bool: + """If layout is 
rowwise only""" + return self.value == JAXX_Quantize_Layout.ROWWISE + + @property + def is_colwise_only(self) -> bool: + """If layout is colwise only""" + return self.value == JAXX_Quantize_Layout.COLWISE + + def __eq__(self, other): + """Compare this quantize layout with another. + + Args: + other: The other quantize layout to compare with + + Returns: + True if the modes are equal, False otherwise + """ + if not isinstance(other, QuantizeLayout): + return False + return self.value == other.value diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 306603bbe..4edc18779 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -15,10 +15,11 @@ import jax import jax.numpy as jnp from jax.tree_util import register_pytree_node_class -from transformer_engine_jax import QuantizeLayout from transformer_engine.common import recipe from .scaling_modes import ScalingMode +from .misc import QuantizeLayout +from .hadamard import apply_rht from .tensor import ( ScaledTensor, ScaledTensor1x, @@ -27,15 +28,15 @@ NoScaleTensor, ) from .helper import ( - get_quantize_config, - get_quantize_config_class, + get_global_quantize_recipe, + get_quantize_config_with_recipe, AmaxComputeAlgo, TensorSource, ) from .device_utils import is_fp8_gemm_with_all_layouts_supported +from ..sharding import get_num_devices_in_mesh __all__ = [ - "QuantizeLayout", "Quantizer", "QuantizerSet", "CurrentScaleQuantizer", @@ -49,7 +50,7 @@ def compute_scale_from_amax( - amax: jnp.ndarray, q_dtype: jnp.dtype, scale: Optional[jnp.ndarray] = None + amax: jnp.ndarray, q_dtype: jnp.dtype, margin: float, scale: Optional[jnp.ndarray] = None ) -> jnp.ndarray: """Compute scale from amax value. 
@@ -63,9 +64,10 @@ def compute_scale_from_amax( fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) if scale is None: scale = jnp.ones((1,)) - sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) + sf = (fp8_max / amax) / (2**margin) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) + assert sf.shape == (1,), f"Expected sf.shape == (1,), but got {sf.shape}" return sf @@ -81,12 +83,15 @@ class Quantizer(ABC): q_dtype: The data type for quantized values scaling_mode: The scaling mode to use for quantization q_layout: The quantization axis (row-wise, column-wise, or both) + data_layout: The data layout string (e.g., "NT") + checkpoint_name: Optional name for checkpointing quantization state """ q_dtype: jnp.dtype scaling_mode: ScalingMode q_layout: QuantizeLayout data_layout: str + checkpoint_name: Optional[str] = None def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -95,7 +100,13 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = () - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + ) return (children, aux_data) @classmethod @@ -115,14 +126,6 @@ def update(self, *args, **kwargs): """Update quantizer state (no-op in base class).""" del args, kwargs - def is_2x2x(self) -> bool: - """Check if quantizer uses both row-wise and column-wise quantization. - - Returns: - True if using both row-wise and column-wise quantization - """ - return self.q_layout == QuantizeLayout.ROWWISE_COLWISE - def get_data_layout(self) -> str: """Get the data data_layout string. 
@@ -132,11 +135,11 @@ def get_data_layout(self) -> str: Raises: ValueError: If quantization axis is invalid """ - if self.q_layout == QuantizeLayout.ROWWISE_COLWISE: + if self.q_layout.is_rowwise_colwise: return self.data_layout - if self.q_layout == QuantizeLayout.ROWWISE: + if self.q_layout.is_rowwise_only: return self.data_layout[0] - if self.q_layout == QuantizeLayout.COLWISE: + if self.q_layout.is_colwise_only: return self.data_layout[1] raise ValueError(f"Invalid q_layout: {self.q_layout}") @@ -155,7 +158,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> """ def quantize( - self, x, is_rowwise=False, is_colwise=False, dq_dtype=None, flatten_axis=-1, **kwargs + self, x, is_rowwise=None, is_colwise=None, dq_dtype=None, flatten_axis=-1, **kwargs ) -> ScaledTensor: """Quantize a tensor using the internal _quantize_func(). @@ -170,7 +173,11 @@ def quantize( A ScaledTensor1x or ScaledTensor2x containing the quantized data """ del kwargs - if (is_rowwise and is_colwise) or self.is_2x2x(): + + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise + + if is_rowwise and is_colwise: rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = self._quantize_func( x, is_colwise=True, dq_dtype=dq_dtype, flatten_axis=flatten_axis @@ -216,6 +223,7 @@ class CurrentScaleQuantizer(Quantizer): Attributes: scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING q_layout: Quantization axis (default: ROWWISE_COLWISE) + data_layout: Data layout string (default: "NT") """ scaling_mode: ScalingMode = ScalingMode.CURRENT_TENSOR_SCALING @@ -247,8 +255,7 @@ def _quantize_func( compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) - fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) - scale = (fp8_max / 
amax) / (2 ** get_quantize_config().MARGIN) + scale = compute_scale_from_amax(amax, self.q_dtype, margin=0.0) scaled_x = x.data.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) @@ -284,16 +291,8 @@ def quantize( flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" - is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise rowwise_tensor = self._quantize_func(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) colwise_tensor = None @@ -328,17 +327,23 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): Attributes: scaling_mode: Set to NVTE_DELAYED_TENSOR_SCALING q_layout: Quantization axis (default: ROWWISE_COLWISE) + data_layout: Data layout string (default: "NT") + margin: Margin value for scale computation + amax_compute_algo: Algorithm for computing amax scale: Current scaling factor amax_history: History of maximum absolute values """ - scaling_mode: ScalingMode = ScalingMode.DELAYED_TENSOR_SCALING - q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE + margin: float = 0.0 + amax_compute_algo: AmaxComputeAlgo = AmaxComputeAlgo.MAX scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) - amax_history: jnp.ndarray = field( - default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32) - ) + amax_history: jnp.ndarray = field(default_factory=lambda: jnp.zeros((1024,), jnp.float32)) + + def __post_init__(self): + assert self.margin is not None, "margin must be specified" + assert self.amax_compute_algo is not None, "amax_compute_algo must be specified" 
+ assert self.amax_history is not None, "amax_history must be specified" def tree_flatten(self): """Flatten the quantizer for JAX tree operations. @@ -347,7 +352,15 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.scale, self.amax_history) - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + self.margin, + self.amax_compute_algo, + ) return (children, aux_data) def _quantize_func( @@ -380,6 +393,7 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + # Note, this updating of amax here will only be called once because the "quantize" method impl inherited from CurrentScaleQuantizer only calls _quantize_func once then transposes the result for colwise quantization. So we don't have to worry about update being called twice for 2x2x quantization. self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, @@ -401,12 +415,14 @@ def _update_amax_history(amax_history, new_amax): Returns: Updated AMAX history """ - amax_history = amax_history.at[0].set(new_amax[0]) + amax_history = amax_history.at[0].set(new_amax.reshape((1,))[0]) return amax_history @staticmethod - @partial(jax.jit, static_argnums=(2,)) - def _compute_scale(amax_history, scale, q_dtype): + @partial(jax.jit, static_argnums=(2, 3, 4)) + def _compute_scale( + amax_history, scale, q_dtype, amax_compute_algo: AmaxComputeAlgo, margin: float + ): """Compute new scale based on AMAX history. Args: @@ -418,12 +434,12 @@ def _compute_scale(amax_history, scale, q_dtype): Updated scale value """ # 2. 
Calculate the current scale - if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + if amax_compute_algo is AmaxComputeAlgo.MAX: amax = jnp.max(amax_history, axis=-1, keepdims=True) else: amax = amax_history[0:1] - return compute_scale_from_amax(amax, q_dtype, scale=scale) + return compute_scale_from_amax(amax, q_dtype, margin=margin, scale=scale) @staticmethod @jax.jit @@ -447,7 +463,9 @@ def update(self, new_amax: jnp.ndarray): new_amax: New maximum absolute value to add to history """ amax_history = self._update_amax_history(self.amax_history, new_amax) - self.scale = self._compute_scale(amax_history, self.scale, self.q_dtype) + self.scale = self._compute_scale( + amax_history, self.scale, self.q_dtype, self.amax_compute_algo, self.margin + ) self.amax_history = self._roll_and_reset_amax_history(amax_history) @@ -494,7 +512,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> dq_dtype = dq_dtype if dq_dtype is not None else x.dtype x_shape = x.shape scale_shape = self.scaling_mode.get_scale_shape( - x_shape, is_colwise, is_padded=False, flatten_axis=flatten_axis + x_shape, is_colwise=is_colwise, is_padded=False, flatten_axis=flatten_axis ) scale_dtype = self.scaling_mode.get_scale_dtype() x = x.reshape( @@ -563,6 +581,259 @@ def _e8m0_to_dtype(self, x, dtype): return new_x.astype(dtype) +@register_pytree_node_class +@dataclass +class NVFP4Quantizer(Quantizer): + """Quantizer implementation using current scaling. + + This quantizer uses current scaling mode with float32 scales + + Attributes: + scaling_mode: Set to NVFP4_1D_SCALING or NVFP4_2D_SCALING + q_layout: Quantization axis + data_layout: Data layout string (default: "NT") + stochastic_rounding_rng_state: RNG state for stochastic rounding, must be of shape (4,) and dtype uint32. If None, stochastic rounding is disabled. + use_rht: Whether to apply Randomized Hadamard Transform (RHT) before quantization. 
+ """ + + scaling_mode: ScalingMode = ScalingMode.NVFP4_1D_SCALING + q_layout: QuantizeLayout = QuantizeLayout.ROWWISE_COLWISE + data_layout: str = "NT" + use_rht: bool = False + stochastic_rounding_rng_state: Optional[jnp.ndarray] = None + + def __post_init__(self): + assert ( + self.q_dtype == jnp.float4_e2m1fn + ), "NVFP4 quantization must use a q_dtype of float4_e2m1fn" + assert self.scaling_mode.is_nvfp4_scaling, "NVFP4Quantizer must use NVFP4 scaling modes" + + def tree_flatten(self): + """Flatten the quantizer for JAX tree operations. + + Returns: + Tuple of (children, aux_data) for tree operations + """ + children = (self.stochastic_rounding_rng_state,) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + self.use_rht, + ) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstruct a quantizer from its flattened representation. + + Args: + aux_data: Auxiliary data containing quantizer parameters + children: Unused children data + + Returns: + A reconstructed Quantizer instance + """ + stochastic_rounding_rng_state = children[0] + return cls(*aux_data, stochastic_rounding_rng_state=stochastic_rounding_rng_state) + + def _apply_stochastic_rounding(self, x): + assert ( + self.stochastic_rounding_rng_state is not None + ), "Stochastic rounding RNG state is not initialized" + expected_sr_rng_state_shape = (get_num_devices_in_mesh(), 4) + assert self.stochastic_rounding_rng_state.shape == expected_sr_rng_state_shape, ( + "Stochastic rounding RNG state must be of shape (num_devices_in_mesh, 4). 
Expected" + f" {expected_sr_rng_state_shape}, but got {self.stochastic_rounding_rng_state.shape}" + ) + assert ( + self.stochastic_rounding_rng_state.dtype == jnp.uint32 + ), "Stochastic rounding RNG state must be of dtype uint32" + + # Default RNG state in JAX expects 2x 32-bit integers, use first 2 uint32s for initial state and fold in the other 2 uint32s + key_bits = jnp.array( + [ + # only take the first device's RNG state as the pure-JAX stochastic rounding impl only uses a single-device + self.stochastic_rounding_rng_state[0][0], + self.stochastic_rounding_rng_state[0][1], + ], + dtype=jnp.uint32, + ) + key = jax.random.wrap_key_data(key_bits) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][2]) + key = jax.jit(jax.random.fold_in)(key, self.stochastic_rounding_rng_state[0][3]) + + abs_x = jnp.abs(x) + sign_x = jnp.sign(x) + + floor = ( + (abs_x >= 0.5) * 0.5 + + (abs_x >= 1) * 0.5 + + (abs_x >= 2) + + (abs_x >= 3) + + (abs_x >= 4) + + (abs_x >= 6) * 2 + ) + ceil = ( + 0.5 + + (abs_x > 0.5) * 0.5 + + (abs_x > 1) * 1 + + (abs_x > 2) + + (abs_x > 3) + + (abs_x > 4) * 2 + ) + frac = (abs_x - floor) / (ceil - floor) + + rand = jax.random.uniform(key, abs_x.shape) + return sign_x * jnp.where(frac >= rand, ceil, floor) + + def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> ScaledTensor1x: + """Quantize function helper for block scaling FP8. 
+ + Args: + x: Input tensor to quantize + is_colwise: Whether to use column-wise quantization + dq_dtype: Data type for dequantized values + flatten_axis: The quantization axis for the tensor + + Returns: + A ScaledTensor1x containing the quantized data + """ + # TODO(Phuong): use quantize_func from JAX + if flatten_axis < 0: + flatten_axis = x.ndim + flatten_axis + assert ( + 0 <= flatten_axis < x.ndim + ), f"Invalid flatten_axis: {flatten_axis} for tensor of shape {x.shape}" + + should_apply_rht = self.scaling_mode == ScalingMode.NVFP4_1D_SCALING and is_colwise + + global_amax = None + if isinstance(x, NoScaleTensor): + global_amax = ( + x.amax if not should_apply_rht else None + ) # RHT changes the amax so don't use precalculated amax for colwise 1D nvfp4 quantization with RHT + x = x.data + + # Transpose if required + rowwise_flatten_axis = flatten_axis + data_layout = self.data_layout[0] + if is_colwise: + x = jnp.transpose(x, (*range(flatten_axis, x.ndim), *range(flatten_axis))) + data_layout = self.data_layout[1] + # convert flatten_axis from N layout to T layout + flatten_axis = x.ndim - flatten_axis + x_shape = x.shape + + # We currently only have a single flag 'use_rht' on the quantizer. To avoid an unused rowwise flag, we assume RHT is only used for colwise quantization for now. 
+ use_rht = self.use_rht and is_colwise and self.scaling_mode == ScalingMode.NVFP4_1D_SCALING + if use_rht: + x = apply_rht(x) + + dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + scale_shape = self.scaling_mode.get_scale_shape( + x_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=rowwise_flatten_axis, + ) + scale_dtype = self.scaling_mode.get_scale_dtype() + x = x.reshape( + *x_shape[: flatten_axis - 1], + scale_shape[flatten_axis - 1], + int(x_shape[flatten_axis - 1] / scale_shape[flatten_axis - 1]), + *x_shape[flatten_axis:-1], + scale_shape[-1], + int(x_shape[-1] / scale_shape[-1]), + ) + + # Dtype max constants + DATA_DTYPE_MAX = jnp.finfo(self.q_dtype).max.astype(jnp.float32) + SCALE_DTYPE_MAX = jnp.finfo(scale_dtype).max.astype(jnp.float32) + + # Level 1: Current Tensor Scaling + global_amax = ( + global_amax + if global_amax is not None + else jnp.max(jnp.abs(x)).reshape((1,)).astype(jnp.float32) + ) + tensor_scale = DATA_DTYPE_MAX * SCALE_DTYPE_MAX / global_amax + tensor_scale = jnp.minimum( + tensor_scale, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32) + ) + tensor_scale = jnp.where( + tensor_scale == jnp.array(0.0, dtype=jnp.float32), + jnp.array(1.0, dtype=jnp.float32), + tensor_scale, + ) + tensor_scale_inv = 1.0 / tensor_scale + + # Level 2: Block Scaling + block_amax = jnp.max(jnp.abs(x), axis=(flatten_axis + 2 - 2, -1), keepdims=True).astype( + jnp.float32 + ) + block_scale_inv = jnp.divide(block_amax, DATA_DTYPE_MAX) + block_scale_inv = block_scale_inv * tensor_scale + block_scale_inv = jnp.minimum( + block_scale_inv, jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32) + ) + block_scale_inv = jnp.clip(block_scale_inv, -SCALE_DTYPE_MAX, SCALE_DTYPE_MAX) + # We cast block_scale_inv to scale_dtype here to account for any rounding during the cast. This will ensure the quantized data incorporates the rounded scale value into its computation so dequantization is accurate. 
+ block_scale_inv = block_scale_inv.astype(scale_dtype) + # Note, with JIT jax removes this intermediate cast leading to slightly incorrect results during DQ and worse convergence to the original tensor during many samples of Q+SR->DQ. So we use reduce_precision to simulate the cast to scale_dtype. + assert scale_dtype == jnp.float8_e4m3fn, "Only float8_e4m3fn is supported for scale_dtype" + block_scale_inv = jax.lax.reduce_precision(block_scale_inv, 4, 3) + block_scale = jnp.minimum( + jnp.divide(1.0, block_scale_inv.astype(jnp.float32) * tensor_scale_inv), + jnp.array(jnp.finfo(jnp.float32).max, dtype=jnp.float32), + ) + + # Apply scaling + scaled_x = x.astype(jnp.float32) * block_scale + if self.stochastic_rounding_rng_state is not None: + scaled_x = self._apply_stochastic_rounding(scaled_x) + clipped_x = jnp.clip(scaled_x, -DATA_DTYPE_MAX, DATA_DTYPE_MAX) + + # Cast to the right dtype + quantized_data = clipped_x.reshape(x_shape).astype(self.q_dtype) + block_scale_inv = block_scale_inv.reshape(scale_shape).astype(scale_dtype) + + # In the 2D scaling mode, the scale shape is 2D but it needs to be broadcasted to 1D for GEMM. 
+ # TODO(Phuong): expose this broadcast_2d_scale_shape_to_1d option to the + # quantizer.quantize() API + broadcasted_1d_scale_shape = self.scaling_mode.get_scale_shape( + x_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=False, + flatten_axis=rowwise_flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + + # Broadcast and tile x to match the target shape + def repeat_to_shape(x, target_shape): + x_shape = x.shape + reps = [int(t // s) for s, t in zip(x_shape, target_shape)] + return jnp.tile(x, reps) + + block_scale_inv = repeat_to_shape(block_scale_inv, broadcasted_1d_scale_shape) + + return ScaledTensorFactory.create_1x( + data=quantized_data, + data_layout=data_layout, + is_colwise=is_colwise, + scale_inv=block_scale_inv, + amax=global_amax, + scaling_mode=self.scaling_mode, + dq_dtype=dq_dtype, + flatten_axis=rowwise_flatten_axis, + has_rht_applied=use_rht, + ) + + @register_pytree_node_class @dataclass class QuantizerSet: @@ -630,7 +901,14 @@ def tree_flatten(self): Tuple of (children, aux_data) for tree operations """ children = (self.quantizers,) - aux_data = (self.q_dtype, self.scaling_mode, self.q_layout, self.data_layout, self.n_groups) + aux_data = ( + self.q_dtype, + self.scaling_mode, + self.q_layout, + self.data_layout, + self.checkpoint_name, + self.n_groups, + ) return (children, aux_data) def __post_init__(self): @@ -712,16 +990,8 @@ def quantize( flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" 
- is_rowwise = ( - is_rowwise - if is_rowwise is not None - else (self.q_layout == QuantizeLayout.ROWWISE or self.is_2x2x()) - ) - is_colwise = ( - is_colwise - if is_colwise is not None - else (self.q_layout == QuantizeLayout.COLWISE or self.is_2x2x()) - ) + is_rowwise = is_rowwise if is_rowwise is not None else self.q_layout.has_rowwise + is_colwise = is_colwise if is_colwise is not None else self.q_layout.has_colwise assert is_rowwise or is_colwise, "No quantization layout is specified" original_shape = x.shape @@ -801,6 +1071,8 @@ class QuantizerFactory: ScalingMode.DELAYED_TENSOR_SCALING: DelayedScaleQuantizer, ScalingMode.CURRENT_TENSOR_SCALING: CurrentScaleQuantizer, ScalingMode.MXFP8_1D_SCALING: BlockScaleQuantizer, + ScalingMode.NVFP4_1D_SCALING: NVFP4Quantizer, + ScalingMode.NVFP4_2D_SCALING: NVFP4Quantizer, } @staticmethod @@ -810,6 +1082,7 @@ def create( q_dtype: jnp.dtype = None, q_layout: QuantizeLayout = None, n_groups: int = None, + checkpoint_name: Optional[str] = None, **kwargs, ) -> Quantizer: """Create one or more quantizers with specified parameters. 
@@ -821,12 +1094,12 @@ def create( q_layout: Quantization axis flatten_axis: The quantization axis for the tensor n_groups: Number of quantizers if GroupedQuantizer + checkpoint_name: Optional name for checkpointing quantizations **kwargs: Additional arguments for quantizer initialization Returns: A single quantizer or tuple of quantizers """ - # (Phuong): add this assert back when NVTE_NO_SCALING is fully implememted assert isinstance(scaling_mode, ScalingMode), "Invalid scaling_mode type" if n_groups: if n_quantizers != 1: @@ -845,7 +1118,11 @@ def create( for _ in range(n_quantizers): quantizers.append( quantizer_type( - q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout, **kwargs + q_dtype=q_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + checkpoint_name=checkpoint_name, + **kwargs, ) ) return quantizers[0] if len(quantizers) == 1 else tuple(quantizers) @@ -859,6 +1136,8 @@ def _create_set( bwd_dtype, is_2x2x, n_groups, + is_inference_mode=False, + checkpoint_name: Optional[str] = None, **kwargs, ) -> QuantizerSet: """Create a set of quantizers for forward and backward passes. @@ -871,6 +1150,8 @@ def _create_set( bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization n_groups + is_inference_mode: Whether to create quantizers for inference mode. 
This option is not fully supported yet + checkpoint_name: Optional name for checkpointing quantizations **kwargs: Additional arguments for quantizer initialization Returns: @@ -882,32 +1163,43 @@ def _create_set( q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE if kernel_scaling_mode.is_1d_block_scaling(): q_layout_kernel = QuantizeLayout.COLWISE - if get_quantize_config().INFERENCE_MODE: + if is_inference_mode: q_layout_dgrad = None if "quantize_meta_set" in kwargs: quantize_meta_set = kwargs.get("quantize_meta_set") - args_x = { - "scale": quantize_meta_set.x.scale, - "amax_history": quantize_meta_set.x.amax_history, - } - args_kernel = { - "scale": quantize_meta_set.kernel.scale, - "amax_history": quantize_meta_set.kernel.amax_history, - } - args_grad = { - "scale": quantize_meta_set.grad.scale, - "amax_history": quantize_meta_set.grad.amax_history, - } + args_x = quantize_meta_set.x.get_kwargs_dictionary() + args_kernel = quantize_meta_set.kernel.get_kwargs_dictionary() + args_grad = quantize_meta_set.grad.get_kwargs_dictionary() else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) + q_x = QuantizerFactory.create( + 1, + x_scaling_mode, + fwd_dtype, + q_layout_x, + n_groups, + checkpoint_name=checkpoint_name, + **args_x, + ) q_kernel = QuantizerFactory.create( - 1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel + 1, + kernel_scaling_mode, + fwd_dtype, + q_layout_kernel, + n_groups, + checkpoint_name=checkpoint_name, + **args_kernel, ) q_dgrad = QuantizerFactory.create( - 1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad + 1, + grad_scaling_mode, + bwd_dtype, + q_layout_dgrad, + n_groups, + checkpoint_name=checkpoint_name, + **args_grad, ) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @@ -919,6 +1211,8 @@ def create_set( bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, n_groups: int = None, + 
checkpoint_name: Optional[str] = None, + # TODO(jberchtold): rename fp8_recipe to quantization_recipe fp8_recipe: Optional[recipe.Recipe] = None, **kwargs, ) -> tuple[Union[tuple[Quantizer], None]]: @@ -926,11 +1220,12 @@ def create_set( Args: n_quantizer_sets: Number of quantizer sets to create - scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode - fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE - bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE - is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X + scaling_mode: Scaling mode to use, default is get the scaling mode from the specified or global recipe + fwd_dtype: Data type for forward pass, default is get the fwd dtype from the specified or global recipe + bwd_dtype: Data type for backward pass, default is get the bwd dtype from the specified or global recipe + is_2x2x: Whether to use 2x2x quantization, default is determined based on the specified or global recipe n_groups: + checkpoint_name: Optional name for checkpointing quantizations fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. **kwargs: Additional arguments for quantizer initialization @@ -945,22 +1240,46 @@ def create_set( " scaling mode differs between x, kernel, and grad in the quantizer set." ) + # TODO(jberchtold): Currently this is a limitation because we only support automatically populating quantizer fields based on a given recipe when using Flax. 
In the generic quantizer logic, we cannot assume Flax is being used, so we require the user to provide the quantize_meta_set created by quantize_config.get_quantize_flax_meta() or the same data created by themselves if they are passing a recipe here directly. + assert ( + fp8_recipe is None or "quantize_meta_set" in kwargs + ), "When fp8_recipe is specified, quantize_meta_set must be provided in kwargs." + + if fp8_recipe is None: + fp8_recipe = get_global_quantize_recipe() + if fp8_recipe is not None: - quantize_config = get_quantize_config_class(fp8_recipe)() + assert scaling_mode is None, ( + "scaling_mode should not be specified when fp8_recipe is provided either directly" + " or through an autocast context." + ) + assert fwd_dtype is None, ( + "fwd_dtype should not be specified when fp8_recipe is provided either directly or" + " through an autocast context." + ) + assert bwd_dtype is None, ( + "bwd_dtype should not be specified when fp8_recipe is provided either directly or" + " through an autocast context." 
+ ) + quantize_config = get_quantize_config_with_recipe(fp8_recipe) x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) - elif scaling_mode is not None: - x_scaling_mode = scaling_mode - kernel_scaling_mode = scaling_mode - grad_scaling_mode = scaling_mode + fwd_dtype = quantize_config.FWD_DTYPE + bwd_dtype = quantize_config.BWD_DTYPE + is_inference_mode = quantize_config.INFERENCE_MODE else: - x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) - kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) - grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + if scaling_mode is not None: + x_scaling_mode = scaling_mode + kernel_scaling_mode = scaling_mode + grad_scaling_mode = scaling_mode + else: + # TODO(jberchtold): make a way to explicitly pass a no scaling recipe here if we need other quantization config attributes in the future since NoOpQuantizeConfig already exists, we just can't use it here with direct recipe passing because we cannot differentiate between fp8_recipe=None meaning no recipe specified vs explicitly no quantization desired. + x_scaling_mode = ScalingMode.NO_SCALING + kernel_scaling_mode = ScalingMode.NO_SCALING + grad_scaling_mode = ScalingMode.NO_SCALING + is_inference_mode = False - fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE - bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: # TODO(Jeremy): check x, kernel, grad separately for 2x if x_scaling_mode.is_1d_block_scaling(): @@ -969,7 +1288,6 @@ def create_set( is_2x2x = not is_fp8_gemm_with_all_layouts_supported() else: # NO_SCALING ignores is_2x2x for now is_2x2x = False - is_inference_mode = get_quantize_config().INFERENCE_MODE assert not is_inference_mode, "Inference mode is not supported yet!" 
q_set = [] @@ -983,6 +1301,8 @@ def create_set( bwd_dtype=bwd_dtype, is_2x2x=is_2x2x, n_groups=n_groups, + is_inference_mode=is_inference_mode, + checkpoint_name=checkpoint_name, **kwargs, ) ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index e81a614f0..eea27a35d 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -17,11 +17,12 @@ import operator import numpy as np -from jax.experimental.custom_partitioning import BATCHING +from jax.experimental.custom_partitioning import BATCHING, CompoundFactor from jax.tree_util import register_pytree_node_class import jax.numpy as jnp -from transformer_engine_jax import JAXX_Scaling_Mode, QuantizeLayout +from transformer_engine_jax import JAXX_Scaling_Mode +from .misc import QuantizeLayout from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -72,16 +73,18 @@ class QuantizeShardyRules: Attributes: input_spec: Specification for the input axes - rowwise_rule: Sharding rule for the row-wise scale tensor, depends on - the axes in `input_spec` - colwise_rule: Likewise for the column-wise scale tensor. - factor_sizes: For block scaling, contains the block size factor, which is - used in `input_spec`. 
+ rowwise_out_spec: Sharding spec for the rowwise quantized data + rowwise_scale_spec: Sharding spec for the rowwise scale + colwise_out_spec: Sharding spec for the colwise quantized data + colwise_scale_spec: Sharding spec for the colwise scale + factor_sizes: For block scaling, contains the block size factor """ input_spec: Tuple[str] - rowwise_rule: Tuple[str] - colwise_rule: Tuple[str] + rowwise_out_spec: Tuple[str] + rowwise_scale_spec: Tuple[str] + colwise_out_spec: Tuple[str] + colwise_scale_spec: Tuple[str] factor_sizes: Dict[str, int] @@ -100,10 +103,19 @@ def get_scale_dtype(self) -> jnp.dtype: The data type used for scale tensors """ + @abstractmethod + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + @abstractmethod def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, @@ -112,6 +124,7 @@ def get_scale_shape( Args: data_shape: The shape of the tensor being quantized + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) @@ -152,14 +165,23 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: @abstractmethod def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. 
Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization + q_layout: The layout of the quantized tensor + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. + is_colwise_transposed: Whether the column-wise tensors are transposed. Returns: The Shardy rules for the scaling mode @@ -180,12 +202,22 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float32 + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. + + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return "NN" + def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = True, ) -> Tuple[int, ...]: """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. 
@@ -198,7 +230,14 @@ def get_scale_shape( Returns: The shape for scale tensors - (1,) """ - del data_shape, is_colwise, is_padded, flatten_axis + del ( + data_shape, + data_layout, + is_colwise, + is_padded, + flatten_axis, + broadcast_2d_scale_shape_to_1d, + ) return (0,) @lru_cache(maxsize=4) @@ -232,22 +271,36 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. + flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Returns: The Shardy rules for the scaling mode """ - del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) - scale_var = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + del broadcast_2d_scale_shape_to_1d + input_spec = tuple(f"{unique_var}_x_{i}" for i in range(len(input_shape))) + output_spec = tuple(input_spec) + return QuantizeShardyRules( + input_spec, + output_spec, + (BATCHING + f"{unique_var}_scale",), + (BATCHING + f"{unique_var}_colwise_output",), + (BATCHING + f"{unique_var}_colwise_scale",), + {}, + ) class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): @@ -264,25 +317,37 @@ def get_scale_dtype(self) -> jnp.dtype: """ return jnp.float32 + def get_data_layout(self) -> str: + """Get the data layout for rowwise and colwise scaling. 
+ + Returns: + The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. + """ + return "NT" + def get_scale_shape( self, data_shape: Tuple[int, ...], + data_layout: str = "N", is_colwise: bool = False, is_padded: bool = True, flatten_axis: int = -1, + broadcast_2d_scale_shape_to_1d: bool = True, ) -> Tuple[int, ...]: """Get the shape for scale tensors in delayed scaling. Args: data_shape: The shape of the tensor being scaled + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether the scaling is column-wise is_padded: Whether to return padded shape flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to True. Returns: The shape for scale tensors - (1,) """ - del is_colwise + del data_layout, is_colwise, broadcast_2d_scale_shape_to_1d if np.prod(data_shape) == 0: return (0,) return (1,) @@ -323,22 +388,41 @@ def get_grouped_scale_shape( return (n_groups,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - flatten_axis: Axis along which data can be flattened to 2D for quantization. - + flatten_axis: Axis along which data can be flattened to 2D for quantization + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. 
+ q_layout: The layout of the quantized tensor + is_colwise_transposed: Whether the colwise scaling is transposed Returns: The Shardy rules for the scaling mode """ - del flatten_axis - input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) - scale_var = BATCHING + unique_var + "_scale_inv" - return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + del broadcast_2d_scale_shape_to_1d + input_spec = tuple(f"{unique_var}x_{i}" for i in range(len(input_shape))) + output_spec = input_spec + colwise_output_spec = (BATCHING + f"{unique_var}_colwise_output",) + + if q_layout.has_colwise: + from ..cpp_extensions.misc import multidim_transpose + + colwise_output_spec = input_spec + if is_colwise_transposed: + colwise_output_spec = multidim_transpose( + colwise_output_spec, transpose_axis=flatten_axis + ) + scale = (BATCHING + unique_var + "_scale_inv",) + return QuantizeShardyRules(input_spec, output_spec, scale, colwise_output_spec, scale, {}) class DelayedScalingModeMetadataImpl(CurrentScalingModeMetadataImpl): @@ -359,14 +443,18 @@ class BlockScalingModeMetadataImpl(ScalingModeMetadataImpl): _block_alignment: Alignment requirements for blocks """ - def __init__(self, block_dims: Tuple[int]): + def __init__(self, block_dims: Tuple[int], scale_dtype: jnp.dtype, data_layout: str): """Initialize block scaling mode implementation. Args: block_dims: Dimensions of the scaling blocks + scale_dtype: Data type of the scale tensor + data_layout: Layout for rowwise and colwise scaling, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout. """ self._block_dims = block_dims + self._scale_dtype = scale_dtype self._block_alignment = (128, 4) + self._data_layout = data_layout def get_scale_dtype(self) -> jnp.dtype: """Get the data type for scale tensors in block scaling. 
@@ -374,7 +462,15 @@ def get_scale_dtype(self) -> jnp.dtype:
         Returns:
             The data type used for scale tensors (float8_e8m0fnu)
         """
-        return jnp.float8_e8m0fnu
+        return self._scale_dtype
+
+    def get_data_layout(self) -> str:
+        """Get the data layout for rowwise and colwise scaling.
+
+        Returns:
+            The data layout, two characters, e.g. "NT", where each is either "N" (default) or "T" for transposed. The first character refers to the rowwise layout and the second refers to the colwise layout.
+        """
+        return self._data_layout

     def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_dim):
         """Remove excess padding from the scale shape and return the shape with respect to the original data shape."""
@@ -402,23 +498,51 @@ def _apply_scale_shape_correction(self, data_shape, n_scale_blocks, scale_block_
     def get_scale_shape(
         self,
         data_shape: Tuple[int, ...],
+        data_layout: str = "N",
         is_colwise: bool = False,
         is_padded: bool = True,
         flatten_axis: int = -1,
+        broadcast_2d_scale_shape_to_1d: bool = False,
     ) -> Tuple[int, ...]:
         """Get the shape for scale tensors in block scaling.

         Args:
             data_shape: The shape of the tensor being quantized
+            data_layout: Layout of the data shape, either "N" (default) or "T" for transposed.
             is_colwise: Whether the scaling is column-wise
             is_padded: Whether to return padded shape
             flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1.
+            broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False.
Returns: The shape for scale tensors """ + flatten_axis = (len(data_shape) + flatten_axis) % len(data_shape) + assert ( + 0 < flatten_axis < len(data_shape) + ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" + block_alignment = self._block_alignment if is_padded else (1, 1) + if is_colwise: + assert data_layout == self._data_layout[1], ( + f"Data layout must match colwise layout, received {data_layout} but expected" + f" {self._data_layout[1]}" + ) + else: + assert data_layout == self._data_layout[0], ( + f"Data layout must match rowwise layout, received {data_layout} but expected" + f" {self._data_layout[0]}" + ) + + if is_colwise and self._data_layout[1] == "T": + # TODO(Phuong): rework this hack so that we don't implicitly change is_colwise value + is_colwise = False # now rowwise in T is colwise in N + if flatten_axis < 0: + flatten_axis = len(data_shape) + flatten_axis + # flatten_axis is given wrt N layout, convert to T layout + flatten_axis = len(data_shape) - flatten_axis + if is_colwise: block_y, block_x = self._block_dims alignment_y, alignment_x = block_alignment @@ -426,12 +550,7 @@ def get_scale_shape( block_x, block_y = self._block_dims alignment_x, alignment_y = block_alignment - if flatten_axis < 0: - flatten_axis = len(data_shape) + flatten_axis - assert ( - 0 < flatten_axis < len(data_shape) - ), f"flatten_axis {flatten_axis} is out of bounds for shape {data_shape}" - + is_block_2d = block_x > 1 and block_y > 1 assert data_shape[flatten_axis - 1] % block_x == 0, ( f"Data shape {data_shape} should be divisible by block_x {block_x} in axis" f" {flatten_axis - 1}" @@ -440,6 +559,9 @@ def get_scale_shape( data_shape[-1] % block_y == 0 ), f"Data shape {data_shape} should be divisible by block_y {block_y} in axis -1" + if broadcast_2d_scale_shape_to_1d and is_block_2d: + block_x = 1 + flattened_first_dim = reduce(operator.mul, data_shape[:flatten_axis], 1) flattened_last_dim = reduce(operator.mul, data_shape[flatten_axis:], 1) 
@@ -562,52 +684,98 @@ def get_grouped_scale_shape( return (n_block_x * n_block_y,) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis + self, + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + is_colwise_transposed, ) -> QuantizeShardyRules: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix - + flatten_axis: Axis along which data can be flattened to 2D for quantization + q_layout: The layout of the quantized tensor + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. + is_colwise_transposed: Whether the column-wise tensors are transposed. Returns: The Shardy rules for the scaling mode """ - del flatten_axis - input_spec = [f"{unique_var}{i}" for i in range(input_rank)] - rowwise = [f"{unique_var}scale_inv_rowwise{i}" for i in range(input_rank)] - colwise = [f"{unique_var}scale_inv_colwise{i}" for i in range(input_rank)] - - # NOTE (Alp): Padding the scales breaks the size relationship in CompoundFactors. - # Unfortunately, because Shardy rules are applied to the inner primitive, the - # only way to preserve the relationship is to lower unpadded scales to the - # underlying custom call and pad them in C++. Until that's implemented, the - # Shardy rules for block scales have to be completely disconnected from the - # Shardy rules for the tensor they belong to. - - # # We have to use two different factors in the two CompoundFactors because of Shardy - # # verifier requirements, even though they are the same. 
- # rowwise_var = unique_var - # colwise_var = f"{unique_var}_" - # input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "block_size_colwise") - # input_spec[-1] = CompoundFactor(rowwise_var, "block_size_rowwise") + is_rowwise = q_layout.has_rowwise + is_colwise = q_layout.has_colwise - # # The rowwise and colwise scale tensors should be sharded the same way as the input. - # # However, we need to adjust the dimensions where the block scaling factor applies. - # rowwise = input_spec.copy() - # rowwise[-1] = rowwise_var + input_rank = len(input_shape) + flatten_axis = (flatten_axis + input_rank) % input_rank + input_spec = [f"{unique_var}_x_{i}" for i in range(input_rank)] - # colwise = input_spec.copy() - # colwise[flatten_axis - 1] = colwise_var + assert ( + self._block_dims[1] != 1 + ), f"Expect 1D rowwise or 2D block. Got _block_dims={self._block_dims}" + # For 2D block scaling, only support when with broadcast_2d_scale_shape_to_1d + if self._block_dims[0] != 1: + assert self._block_dims[0] == self._block_dims[1] and broadcast_2d_scale_shape_to_1d, ( + f"Got broadcast_2d_scale_shape_to_1d={broadcast_2d_scale_shape_to_1d}," + f" _block_dims={self._block_dims}" + ) + + block_size_1d = self._block_dims[1] + + # We have to use two different factors in the two CompoundFactors because of Shardy + # verifier requirements, even though they are the same. 
+ # No CompoundFactor is needed if the dim has the same size as the blocksize + blocksizes = {} + rowwise_var = f"{unique_var}_None" + colwise_var = f"{unique_var}_None" + if is_rowwise and not input_shape[-1] == block_size_1d: + rowwise_var = input_spec[-1] + "_compound" + input_spec[-1] = CompoundFactor(rowwise_var, "blocksize_x") + blocksizes["blocksize_x"] = block_size_1d + if is_colwise and not input_shape[flatten_axis - 1] == block_size_1d: + colwise_var = input_spec[flatten_axis - 1] + "_compound" + input_spec[flatten_axis - 1] = CompoundFactor(colwise_var, "blocksize_y") + blocksizes["blocksize_y"] = block_size_1d + + # The rowwise and colwise scale tensors should be sharded the same way as the input. + # However, we need to adjust the dimensions where the block scaling factor applies. + if is_rowwise: + rowwise_out = input_spec.copy() + rowwise_scale = input_spec.copy() + rowwise_scale[-1] = rowwise_var + else: + rowwise_out = [ + BATCHING + f"{unique_var}_rowwise_output", + ] + rowwise_scale = [ + BATCHING + f"{unique_var}_rowwise_scale_inv", + ] - # # This implementation needs to be updated for different block dims. 
- # assert self._block_dims == (1, 32) + if is_colwise: + colwise_out = input_spec.copy() + colwise_scale = input_spec.copy() + colwise_scale[flatten_axis - 1] = colwise_var + if is_colwise_transposed: + from ..cpp_extensions.misc import multidim_transpose + + colwise_out = multidim_transpose(colwise_out, transpose_axis=flatten_axis) + colwise_scale = multidim_transpose(colwise_scale, transpose_axis=flatten_axis) + else: + colwise_out = [ + BATCHING + f"{unique_var}_colwise_output", + ] + colwise_scale = [ + BATCHING + f"{unique_var}_colwise_scale_inv", + ] return QuantizeShardyRules( tuple(input_spec), - tuple(rowwise), - tuple(colwise), - {}, # {"block_size_rowwise": 32, "block_size_colwise": 32}, + tuple(rowwise_out), + tuple(rowwise_scale), + tuple(colwise_out), + tuple(colwise_scale), + blocksizes, ) @@ -620,6 +788,8 @@ class ScalingMode(Enum): - DELAYED_TENSOR_SCALING: Uses delayed scaling with FP8 data type and float32 scales - MXFP8_1D_SCALING: Uses block-based scaling with FP8 data type and E8M0 scales - CURRENT_TENSOR_SCALING: Uses current scaling with FP8 data type and float32 scales + - NVFP4_1D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales + - NVFP4_2D_SCALING: Uses block-based scaling with FP4 data type and E4M3 scales - NO_SCALING: No scaling applied """ @@ -627,6 +797,8 @@ class ScalingMode(Enum): DELAYED_TENSOR_SCALING = JAXX_Scaling_Mode.DELAYED_TENSOR_SCALING MXFP8_1D_SCALING = JAXX_Scaling_Mode.MXFP8_1D_SCALING CURRENT_TENSOR_SCALING = JAXX_Scaling_Mode.CURRENT_TENSOR_SCALING + NVFP4_1D_SCALING = JAXX_Scaling_Mode.NVFP4_1D_SCALING + NVFP4_2D_SCALING = JAXX_Scaling_Mode.NVFP4_2D_SCALING def _get_impl(self) -> ScalingModeMetadataImpl: """Get the implementation for this scaling mode. 
@@ -650,40 +822,79 @@ def get_scale_dtype(self): """ return self._get_impl().get_scale_dtype() - def get_scale_shape_2x(self, data_shape, is_padded=True, flatten_axis=-1) -> Tuple[Tuple[int]]: + def get_scale_shape_2x( + self, data_shape, is_padded=True, flatten_axis=-1, broadcast_2d_scale_shape_to_1d=False + ) -> Tuple[Tuple[int]]: """Get shapes for both row-wise and column-wise scaling. Args: data_shape: Shape of the data tensor is_padded: Whether to use padded shapes flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: Tuple of (rowwise_scale_shape, colwise_scale_shape) """ + data_layout = self._get_impl().get_data_layout() + rowwise_layout = data_layout[0] + assert ( + rowwise_layout == "N" + ), f"For rowwise layout only 'N' is supported, received {rowwise_layout}" + colwise_layout = data_layout[1] + rowwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=False, is_padded=is_padded, flatten_axis=flatten_axis + data_shape, + data_layout=rowwise_layout, + is_colwise=False, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, ) + + colwise_data_shape = data_shape + if colwise_layout == "T": + colwise_data_shape = data_shape[flatten_axis:] + data_shape[:flatten_axis] colwise_scale_shape = self.get_scale_shape( - data_shape, is_colwise=True, is_padded=is_padded, flatten_axis=flatten_axis + colwise_data_shape, + data_layout=colwise_layout, + is_colwise=True, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, ) return (rowwise_scale_shape, colwise_scale_shape) def get_scale_shape( - self, data_shape, is_colwise, is_padded=True, flatten_axis=-1 + self, + data_shape, + data_layout="N", + is_colwise=False, + is_padded=True, + flatten_axis=-1, + broadcast_2d_scale_shape_to_1d=False, ) 
-> Tuple[int]: """Get the shape for scale tensors in this mode. Args: data_shape: Shape of the data tensor + data_layout: Layout of the data shape, either "N" (default) or "T" for transposed. is_colwise: Whether to use column-wise scaling is_padded: Whether to use padded shapes flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. Returns: The shape for scale tensors """ - return self._get_impl().get_scale_shape(data_shape, is_colwise, is_padded, flatten_axis) + return self._get_impl().get_scale_shape( + data_shape, + data_layout=data_layout, + is_colwise=is_colwise, + is_padded=is_padded, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=broadcast_2d_scale_shape_to_1d, + ) def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: """Get the quantize layout for the tensor usage. @@ -697,18 +908,33 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: return self._get_impl().get_quantize_layout(usage) def get_shardy_sharding_rules( - self, input_rank, unique_var, flatten_axis=-1 + self, + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d=False, ) -> Tuple[Tuple[str]]: """Sharding rules for the input and (row, col)wise scale tensors. Args: - input_rank: The rank of the input tensor (for which we produce the scale tensor) + input_shape: The shape of the input tensor (for which we produce the scale tensor) unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + q_layout: The layout of the quantized tensor + broadcast_2d_scale_shape_to_1d: Whether to broadcast the 2D scale shape to 1D. Defaults to False. 
Returns: The Shardy rules for the scaling mode """ - return self._get_impl().get_shardy_sharding_rules(input_rank, unique_var, flatten_axis) + return self._get_impl().get_shardy_sharding_rules( + input_shape, + unique_var, + flatten_axis, + q_layout, + broadcast_2d_scale_shape_to_1d, + self.is_colwise_transposed, + ) def get_grouped_scale_shape_2x( self, data_shape, n_groups, group_axis, is_padded=True, flatten_axis=-1 @@ -782,8 +1008,64 @@ def is_1d_block_scaling(self) -> bool: Returns: True if the scaling mode is 1D block scaling, False otherwise """ + # Both 1D and 2D NVFP4 scaling are treated as 1D block scaling since the 2D scales are broadcast to 1D because it is required for the GEMM. + return self == ScalingMode.MXFP8_1D_SCALING or self.is_nvfp4_scaling + + @property + def is_block_scaling(self) -> bool: + """Check if this scaling mode is block scaling. + + Returns: + True if the scaling mode is block scaling, False otherwise + """ + # Currently we only have 1D block scaling modes + return self.is_1d_block_scaling() + + def get_compatible_q_dtypes(self) -> set[jnp.dtype]: + """Returns a set of compatible quantized data types for this scaling mode. + + Returns: + A set of compatible quantized data types + """ + if self in ( + ScalingMode.DELAYED_TENSOR_SCALING, + ScalingMode.CURRENT_TENSOR_SCALING, + ScalingMode.MXFP8_1D_SCALING, + ): + return {jnp.float8_e5m2, jnp.float8_e4m3fn} + if self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING): + return {jnp.float4_e2m1fn} + if self == ScalingMode.NO_SCALING: + return {jnp.float16, jnp.bfloat16, jnp.float32} + raise ValueError(f"Invalid scaling mode: {self}") + + @property + def is_nvfp4_scaling(self) -> bool: + """Check if this scaling mode is NVFP4 scaling. 
+
+        Returns:
+            True if the scaling mode is NVFP4 scaling, False otherwise
+        """
+        return self in (ScalingMode.NVFP4_1D_SCALING, ScalingMode.NVFP4_2D_SCALING)
+
+    @property
+    def is_mxfp8_scaling(self) -> bool:
+        """Check if this scaling mode is MXFP8 scaling.
+
+        Returns:
+            True if the scaling mode is MXFP8 scaling, False otherwise
+        """
         return self == ScalingMode.MXFP8_1D_SCALING

+    @property
+    def is_colwise_transposed(self) -> bool:
+        """Check if this scaling mode uses transposed layout for column-wise scaling.
+
+        Returns:
+            True if the scaling mode uses transposed layout for column-wise scaling, False otherwise
+        """
+        return self.is_tensor_scaling() or self.is_nvfp4_scaling
+
     def __eq__(self, other):
         """Compare this scaling mode with another.

@@ -820,9 +1102,20 @@ def tree_unflatten(cls, aux_data, _children):


 SCALING_MODES_TO_IMPL: Dict[ScalingMode, ScalingModeMetadataImpl] = {
+    ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
     ScalingMode.DELAYED_TENSOR_SCALING: DelayedScalingModeMetadataImpl(),
-    ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)),
-    # WAR
+    ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(
+        block_dims=(1, 32),
+        scale_dtype=jnp.float8_e8m0fnu,
+        data_layout="NN",
+    ),
     ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(),
-    ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(),
+    ScalingMode.NVFP4_1D_SCALING: BlockScalingModeMetadataImpl(
+        block_dims=(1, 16),
+        scale_dtype=jnp.float8_e4m3fn,
+        data_layout="NT",
+    ),
+    ScalingMode.NVFP4_2D_SCALING: BlockScalingModeMetadataImpl(
+        block_dims=(16, 16), scale_dtype=jnp.float8_e4m3fn, data_layout="NT"
+    ),
 }
diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py
index dbbac4abc..90f139c3d 100644
--- a/transformer_engine/jax/quantize/tensor.py
+++ b/transformer_engine/jax/quantize/tensor.py
@@ -14,11 +14,12 @@
 import jax.numpy as jnp
 from jax.tree_util import register_pytree_node_class
+from 
jax.ad_checkpoint import checkpoint_name as jax_checkpoint_name -from transformer_engine_jax import QuantizeLayout from .scaling_modes import ScalingMode, TensorUsage from .dequantizer import ScalingModeToDequantizerMap +from .misc import QuantizeLayout from ..sharding import ( with_sharding_constraint_by_logical_axes as original_with_sharding_constraint_by_logical_axes, ) @@ -89,6 +90,17 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st The tensor with applied sharding constraints """ + @abstractmethod + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + @dataclass class AbstractBaseTensor1x(AbstractBaseTensor): @@ -128,9 +140,7 @@ def dequantize(self): def get_tensor(self, usage: TensorUsage): """Returns the tensor based on the tensor usage.""" q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) - assert ( - q_layout == QuantizeLayout.ROWWISE - ), "Only ROWWISE layout is supported for NoScaleTensor" + assert q_layout.is_rowwise_only, "Only ROWWISE layout is supported for NoScaleTensor" return self def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): @@ -152,6 +162,18 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st amax=self.amax, ) + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. + + Returns: + The checkpointed tensor + """ + assert quantizer is None, "NoScaleTensor does not support quantization." 
+ return self + class ScaledTensor(ABC): """Abstract base class for scaled tensors.""" @@ -175,6 +197,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): is_colwise: Whether the tensor uses column-wise quantization data_layout: The data_layout specification for the tensor flatten_axis: The quantization axis for the tensor + has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization """ scale_inv: jnp.ndarray @@ -184,6 +207,7 @@ class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): is_colwise: bool data_layout: str flatten_axis: int + has_rht_applied: bool def __post_init__(self): """Validates and adjusts the scale_inv shape after initialization. @@ -201,13 +225,32 @@ def __post_init__(self): else: unpadded_scale_shape = self.scaling_mode.get_scale_shape( self.data.shape, + data_layout=self.data_layout, is_colwise=self.is_colwise, is_padded=False, - flatten_axis=self.flatten_axis, + # expect the flatten_axis wrt the N layout + flatten_axis=( + self.flatten_axis + if self.data_layout == "N" + else self.data.ndim - self.flatten_axis + ), ) - assert self.scale_inv.shape == unpadded_scale_shape, ( - "Unpadded inverse scale factor has wrong shape, expected" - f" {unpadded_scale_shape} but got {self.scale_inv.shape}." + unpadded_scale_shape_broadcast = self.scaling_mode.get_scale_shape( + self.data.shape, + data_layout=self.data_layout, + is_colwise=self.is_colwise, + is_padded=False, + # expect the flatten_axis wrt the N layout + flatten_axis=( + self.flatten_axis + if self.data_layout == "N" + else self.data.ndim - self.flatten_axis + ), + broadcast_2d_scale_shape_to_1d=True, + ) + assert self.scale_inv.shape in (unpadded_scale_shape, unpadded_scale_shape_broadcast), ( + f"Unpadded inverse scale factor has wrong shape, expected {unpadded_scale_shape} or" + f" {unpadded_scale_shape_broadcast} but got {self.scale_inv.shape}." 
) def tree_flatten(self): @@ -224,6 +267,7 @@ def tree_flatten(self): self.is_colwise, self.data_layout, self.flatten_axis, + self.has_rht_applied, ) return (children, aux_data) @@ -242,8 +286,8 @@ def dequantize(self): def get_tensor(self, usage: TensorUsage): """Returns the tensor based on the tensor usage.""" q_layout = self.scaling_mode.get_quantize_layout(usage) - colwise_usage_valid = q_layout == QuantizeLayout.COLWISE and self.is_colwise - rowwise_usage_valid = q_layout == QuantizeLayout.ROWWISE and not self.is_colwise + colwise_usage_valid = q_layout.is_colwise_only and self.is_colwise + rowwise_usage_valid = q_layout.is_rowwise_only and not self.is_colwise if colwise_usage_valid or rowwise_usage_valid: return self @@ -279,24 +323,38 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st data = with_sharding_constraint_by_logical_axes(self.data, axis_names) - if self.scaling_mode == ScalingMode.MXFP8_1D_SCALING: - # TODO(Phuong): Handle padding !? + if self.scaling_mode.is_block_scaling: # Both MXFP8 and NVFP4 scale_inv = with_sharding_constraint_by_logical_axes(self.scale_inv, axis_names) else: scale_inv = self.scale_inv return ScaledTensor1x( data=data, - scale_inv=scale_inv, amax=self.amax, + scale_inv=scale_inv, scaling_mode=self.scaling_mode, dq_dtype=self.dq_dtype, _dq_func=self._dq_func, is_colwise=self.is_colwise, data_layout=self.data_layout, flatten_axis=self.flatten_axis, + has_rht_applied=self.has_rht_applied, ) + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. 
+ + Returns: + The checkpointed tensor + """ + if quantizer is None or quantizer.checkpoint_name is None: + return self + + return jax_checkpoint_name(self, name=quantizer.checkpoint_name) + @register_pytree_node_class @dataclass @@ -335,6 +393,7 @@ def __init__( self.group_sizes = group_sizes self.original_shape = original_shape self.group_axis = group_axis + # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 super().__init__( data=data, scale_inv=scale_inv, @@ -345,6 +404,7 @@ def __init__( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, + has_rht_applied=False, ) def __post_init__(self): @@ -398,6 +458,20 @@ def tree_flatten(self): def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): raise NotImplementedError + def checkpoint(self, quantizer): + """Checkpoints the tensor with the given quantizer's checkpoint name if available. + + Args: + quantizer: The quantizer to use for checkpointing. If None, no checkpointing is applied. 
+ + Returns: + The checkpointed tensor + """ + if quantizer is None or quantizer.checkpoint_name is None: + return self + + return jax_checkpoint_name(self, name=quantizer.checkpoint_name) + @register_pytree_node_class @dataclass @@ -442,10 +516,10 @@ def get_tensor(self, usage: TensorUsage): q_layout_rowwise = self.rowwise_tensor.scaling_mode.get_quantize_layout(usage) q_layout_colwise = self.colwise_tensor.scaling_mode.get_quantize_layout(usage) - if q_layout_rowwise == QuantizeLayout.ROWWISE: + if q_layout_rowwise.is_rowwise_only: return self.rowwise_tensor - if q_layout_colwise == QuantizeLayout.COLWISE: + if q_layout_colwise.is_colwise_only: return self.colwise_tensor raise ValueError( @@ -474,6 +548,9 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return ScaledTensor2x(rowwise_tensor, colwise_tensor) + def checkpoint(self, quantizer): + raise NotImplementedError + @dataclass class ScaledTensorFactory: @@ -496,6 +573,7 @@ def create_1x( group_sizes=None, original_shape=None, group_axis=0, + has_rht_applied=False, ): """Creates a single-scale quantized tensor. 
@@ -511,6 +589,7 @@ def create_1x( group_sizes: Array of ints containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) + has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided @@ -521,13 +600,13 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) if group_sizes is not None: - flatten_axis = len(original_shape) + flatten_axis if flatten_axis < 0 else flatten_axis + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" # Handling attrs of transposed tensors - group_axis = len(original_shape) + group_axis if group_axis < 0 else group_axis + group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": if original_shape[0] == group_sizes.size: original_shape = ( @@ -560,7 +639,7 @@ def create_1x( ) # Handling attrs of transposed tensors - flatten_axis = data.ndim + flatten_axis if flatten_axis < 0 else flatten_axis + flatten_axis = (data.ndim + flatten_axis) % data.ndim if data_layout == "T": flatten_axis = data.ndim - flatten_axis @@ -574,6 +653,7 @@ def create_1x( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, + has_rht_applied=has_rht_applied, ) @staticmethod @@ -583,6 +663,7 @@ def create_2x( colwise_data, colwise_scale_inv, amax=None, + colwise_amax=None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=jnp.bfloat16, data_layout="NN", @@ -590,6 +671,8 @@ def create_2x( group_sizes=None, original_shape=None, group_axis=0, + rowwise_has_rht_applied=False, + colwise_has_rht_applied=False, ): """Creates a double-scale quantized tensor. 
@@ -606,12 +689,16 @@ def create_2x( group_sizes: Array containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) + rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + colwise_has_rht_applied: Whether the column-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) Returns: A ScaledTensor2x instance """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) + if colwise_amax is None: + colwise_amax = amax assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" rowwise_tensor = ScaledTensorFactory.create_1x( @@ -626,11 +713,12 @@ def create_2x( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=rowwise_has_rht_applied, ) colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, - amax, + colwise_amax, scaling_mode, dq_dtype, is_colwise=True, @@ -639,6 +727,7 @@ def create_2x( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=colwise_has_rht_applied, ) return ScaledTensor2x(rowwise_tensor, colwise_tensor) @@ -649,6 +738,7 @@ def create( colwise_data: jnp.ndarray, colwise_scale_inv: jnp.ndarray, amax=None, + colwise_amax=None, scaling_mode: ScalingMode = ScalingMode.NO_SCALING, dq_dtype: jnp.dtype = jnp.bfloat16, data_layout: str = "NN", @@ -657,6 +747,8 @@ def create( group_sizes: jnp.ndarray = None, original_shape: Tuple[int] = None, group_axis: int = 0, + rowwise_has_rht_applied: bool = False, + colwise_has_rht_applied: bool = False, ): """Creates a scaled tensor based on the quantization axis. 
@@ -673,17 +765,22 @@ def create( group_sizes: Array containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) + rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) + colwise_has_rht_applied: Whether the col-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) Returns: Either a ScaledTensor1x or ScaledTensor2x instance depending on q_layout """ - if q_layout == QuantizeLayout.ROWWISE_COLWISE: + assert not rowwise_has_rht_applied, "RHT is not supported for rowwise quantization yet" + + if q_layout.is_rowwise_colwise: return ScaledTensorFactory.create_2x( data, scale_inv, colwise_data, colwise_scale_inv, amax, + colwise_amax, scaling_mode, dq_dtype, data_layout=data_layout, @@ -691,22 +788,24 @@ def create( group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + rowwise_has_rht_applied=rowwise_has_rht_applied, + colwise_has_rht_applied=colwise_has_rht_applied, ) - is_colwise = q_layout == QuantizeLayout.COLWISE - if is_colwise: + if q_layout.is_colwise_only: return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, - amax, + colwise_amax if colwise_amax is not None else amax, scaling_mode, dq_dtype, - is_colwise=is_colwise, + is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=colwise_has_rht_applied, ) return ScaledTensorFactory.create_1x( @@ -715,12 +814,13 @@ def create( amax, scaling_mode, dq_dtype, - is_colwise=is_colwise, + is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, + has_rht_applied=rowwise_has_rht_applied, ) diff --git a/transformer_engine/jax/setup.py 
b/transformer_engine/jax/setup.py index b58d2df7f..619b6070b 100644 --- a/transformer_engine/jax/setup.py +++ b/transformer_engine/jax/setup.py @@ -46,8 +46,8 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import ( rocm_build, copy_common_headers, copy_hipify_tools, - clear_hipify_tools_copy) +from build_tools.utils import copy_common_headers, min_python_version_str +from build_tools.utils import rocm_build, copy_hipify_tools, clear_hipify_tools_copy from build_tools.te_version import te_version from build_tools.jax import setup_jax_extension, install_requirements, test_requirements @@ -57,6 +57,26 @@ CMakeBuildExtension = get_build_ext(BuildExtension, True) +def get_cuda_major_version() -> int: + """Get CUDA major version using Jax backend.""" + + assert ( + jax._src.lib.cuda_versions is not None + ), "GPU backend is required to build TE jax extensions." + + # Jax currently does not have any stable/public method to get cuda version. + # Try using internal function and default to cuda12 if not found. + try: + cuda_version = jax._src.lib.cuda_versions.cuda_runtime_get_version() + cuda_major_version = cuda_version // 1000 + except AttributeError: + cuda_version = os.getenv("CUDA_VERSION", "12") + cuda_major_version = int(cuda_version.split(".")[0]) + + assert cuda_major_version in (12, 13), f"Unsupported cuda version {cuda_version}." + return cuda_major_version + + if __name__ == "__main__": """Main entry point for JAX extension installation. @@ -97,14 +117,26 @@ ) ] + # Setup version and requirements. + # Having the framework extension depend on the core lib allows + # us to detect CUDA version dynamically during compilation and + # choose the correct wheel for te core lib. 
+ __version__ = te_version() + if not rocm_build(): + te_core = f"transformer_engine_cu{get_cuda_major_version()}=={__version__}" + else: + te_core = f"transformer_engine_rocm=={__version__}" + install_requires = install_requirements() + [te_core] + # Configure package setuptools.setup( name="transformer_engine_jax", - version=te_version(), + version=__version__, description="Transformer acceleration library - Jax Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension}, - install_requires=install_requirements(), + python_requires=f">={min_python_version_str()}", + install_requires=install_requires, tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 339e74e2f..7f204e768 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from typing import Callable, Optional import warnings + import jax import jax.numpy as jnp from jax.interpreters import pxla @@ -74,6 +75,16 @@ def get_sharding_map_logic_axis_to_mesh_axis(): """ Generate a dict to map logical axes to mesh axes. 
""" + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + if mesh is None or mesh.empty: + # If no mesh is defined, return an empty dict and do not require a MeshResource context to be present + return {} + + abstract_mesh = get_abstract_mesh() + if sorted(abstract_mesh.manual_axes) == sorted(mesh.axis_names): + # If all mesh axes are manual axes, return an empty dict and do not require a MeshResource context to be present + return {} + gsr = global_mesh_resource() is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 @@ -130,7 +141,22 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): # We want to exclude the axes that already used by shard_map and shard_map # only sets those in the abstract_mesh, not the physical one manual_axis_names = get_abstract_mesh().manual_axes - cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec) + + # Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too + def filter_manual_axes(name_or_tuple): + if isinstance(name_or_tuple, tuple): + out = tuple(n for n in name_or_tuple if n not in manual_axis_names) + if len(out) == 0: + return None + return out + if name_or_tuple in manual_axis_names: + return None + return name_or_tuple + + cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec) + + if cleaned_axis_names == (None,) * len(cleaned_axis_names): + return x cleaned_pspec = PartitionSpec(*cleaned_axis_names) return jax.lax.with_sharding_constraint(x, cleaned_pspec) @@ -222,6 +248,19 @@ def num_of_devices(): return len(jax.devices()) +def get_num_devices_in_mesh(mesh=None): + """ + Get the number of devices in the given mesh. + If the mesh is None, it would be replaced + by the global mesh. 
+ """ + if mesh is None: + mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh + if mesh.empty: + return 1 + return np.prod(list(mesh.shape.values())) + + def get_mesh_axis_size(axis, mesh=None): """ Get the axis size of the given mesh. @@ -349,6 +388,21 @@ def all_reduce_sum_along_dp_fsdp(x: jnp.array, mesh: jax.sharding.Mesh): return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) +def all_reduce_sum_along_dp_fsdp_tpsp(x: jnp.array, mesh: jax.sharding.Mesh): + """Perform all-reduce sum operation along data parallelism and sequence parallelism axes. + + Args: + x: Input tensor to reduce + mesh: JAX mesh for distributed computation + + Returns: + Reduced tensor + """ + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().tpsp_resource, mesh) + x = lax_paral_op(x, jax.lax.psum, global_mesh_resource().dp_resource, mesh) + return lax_paral_op(x, jax.lax.psum, global_mesh_resource().fsdp_resource, mesh) + + def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mesh): """Perform all-reduce max operation along all axes except pipeline parallelism. @@ -364,3 +418,21 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x + + +def tpsp_axis_size(): + """ + Get the size of the tensor parallelism axis. + Return 1 if no TP axis is set. + """ + return get_mesh_axis_size(global_mesh_resource().tpsp_resource) + + +def dp_or_fsdp_axis_size(): + """ + Get the size of the data parallelism or FSDP axis. + Return 1 if no DP/FSDP axis is set. 
+ """ + dp_size = get_mesh_axis_size(global_mesh_resource().dp_resource) + fsdp_size = get_mesh_axis_size(global_mesh_resource().fsdp_resource) + return dp_size if dp_size > 1 else fsdp_size diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 3bdbe4089..9d894a389 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -46,8 +46,18 @@ def torch_version() -> tuple[int, ...]: moe_sort_chunks_by_index, moe_sort_chunks_by_index_with_probs, ) -from transformer_engine.pytorch.fp8 import fp8_autocast -from transformer_engine.pytorch.fp8 import fp8_model_init +from transformer_engine.pytorch.quantization import fp8_autocast +from transformer_engine.pytorch.quantization import fp8_model_init +from transformer_engine.pytorch.quantization import autocast +from transformer_engine.pytorch.quantization import quantized_model_init +from transformer_engine.pytorch.quantization import is_fp8_available +from transformer_engine.pytorch.quantization import is_mxfp8_available +from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available +from transformer_engine.pytorch.quantization import is_nvfp4_available +from transformer_engine.pytorch.quantization import get_default_recipe +from transformer_engine.pytorch.utils import get_cudnn_version +from transformer_engine.pytorch.utils import get_device_compute_capability +from transformer_engine.pytorch.utils import is_bf16_available from transformer_engine.pytorch.graph import make_graphed_callables from transformer_engine.pytorch.distributed import checkpoint from transformer_engine.pytorch.distributed import CudaRNGStatesTracker @@ -56,6 +66,24 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch import optimizers from transformer_engine.pytorch.export import onnx_export from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from transformer_engine.pytorch.quantized_tensor import 
QuantizedTensorStorage +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.quantized_tensor import Quantizer +from transformer_engine.pytorch.quantized_tensor import prepare_for_saving +from transformer_engine.pytorch.quantized_tensor import restore_from_saved +from transformer_engine.pytorch.tensor import Float8Quantizer +from transformer_engine.pytorch.tensor import Float8CurrentScalingQuantizer +from transformer_engine.pytorch.tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor import NVFP4Quantizer +from transformer_engine.pytorch.tensor import Float8TensorStorage +from transformer_engine.pytorch.tensor import MXFP8TensorStorage +from transformer_engine.pytorch.tensor import Float8BlockwiseQTensorStorage +from transformer_engine.pytorch.tensor import NVFP4TensorStorage +from transformer_engine.pytorch.tensor import Float8Tensor +from transformer_engine.pytorch.tensor import MXFP8Tensor +from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor +from transformer_engine.pytorch.tensor import NVFP4Tensor try: torch._dynamo.config.error_on_nested_jit_trace = False diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index a823379f1..038ebc3c0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -16,21 +16,23 @@ import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION - +import torch.nn.functional as F import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - SplitAlongDim, get_device_compute_capability, - combine_tensors, split_tensor_along_dim, ) -from transformer_engine.pytorch.utils import attention_mask_func -from transformer_engine.pytorch.tensor.quantized_tensor import ( - QuantizedTensor, +from transformer_engine.pytorch.utils import attention_mask_func, nvtx_range_push, nvtx_range_pop +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.quantized_tensor import ( + QuantizedTensorStorage, prepare_for_saving, restore_from_saved, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.constants import ( TE_DType, QKVLayouts, @@ -43,7 +45,7 @@ META_O, META_QKV, ) -from transformer_engine.pytorch.fp8 import get_fp8_torch_dtype +from transformer_engine.pytorch.quantization import get_fp8_torch_dtype, FP8GlobalStateManager from transformer_engine.pytorch.distributed import get_distributed_world_size from transformer_engine.pytorch.jit import no_torch_dynamo from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( @@ -51,17 +53,30 @@ ) from transformer_engine.pytorch.attention.dot_product_attention.softmax import FusedScaleMaskSoftmax from transformer_engine.pytorch.attention.inference import InferenceParams +from transformer_engine.pytorch.cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_activation_offload, + NVTE_CPU_OFFLOAD_V1, +) +from transformer_engine.pytorch.cpu_offload_v1 import is_current_layer_offloaded # Import attention utils import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from 
transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, + ConvertTHDtoBSHD, + ConvertBSHDtoTHD, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import ( AttentionLogging as attn_log, ) from transformer_engine.pytorch import export from transformer_engine.pytorch.export import is_in_onnx_export_mode +from transformer_engine.pytorch.graph import is_graph_capturing # Global vars for flash attn v2 and v3 imports flash_attn_cuda_bwd = None @@ -136,6 +151,58 @@ fa_utils.set_flash_attention_3_params() +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" + + +class FP8EmulationFunc(torch.autograd.Function): + """ + Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: + - forward : QKV (quantize+dequantize), P (pass-through), S (quantize+dequantize), O (pass-through) + - backward: dO (quantize+dequantize), dS (pass-through), dP (quantize+dequantize), dQKV (pass-through) + """ + + @staticmethod + def forward(ctx, tensor1, tensor2, tensor3, quantizer, quantizer_name, qkv_layout): + # pylint: disable=missing-function-docstring + if quantizer_name == "QKV_quantizer": + query_layer, key_layer, value_layer = [ + x.contiguous() for x in [tensor1, tensor2, tensor3] + ] + q_fp8, k_fp8, v_fp8 = combine_and_quantize( + qkv_layout, query_layer, key_layer, value_layer, quantizer + ) + tensors = combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=query_layer.dtype + ) + elif quantizer_name in ["S_quantizer", "O_quantizer"]: + t_fp8 = quantizer(tensor1) + tensors = (t_fp8.dequantize(dtype=tensor1.dtype), tensor2, tensor3) + else: + tensors = (tensor1, tensor2, tensor3) + ctx.quantizer = quantizer + ctx.quantizer_name = quantizer_name + ctx.qkv_layout = qkv_layout + 
return tensors[0], tensors[1], tensors[2] + + @staticmethod + def backward(ctx, grad1, grad2, grad3): + # pylint: disable=missing-function-docstring + if ctx.quantizer_name in ["dO_quantizer", "dP_quantizer"]: + dt_fp8 = ctx.quantizer(grad1) + tensors = dt_fp8.dequantize(dtype=grad1.dtype), grad2, grad3 + elif ctx.quantizer_name == "dQKV_quantizer": + query_grad, key_grad, value_grad = [x.contiguous() for x in [grad1, grad2, grad3]] + dq_fp8, dk_fp8, dv_fp8 = combine_and_quantize( + ctx.qkv_layout, query_grad, key_grad, value_grad, ctx.quantizer + ) + tensors = combine_and_dequantize( + ctx.qkv_layout, dq_fp8, dk_fp8, dv_fp8, src_nominal_dtype=query_grad.dtype + ) + else: + tensors = grad1, grad2, grad3 + return tensors[0], tensors[1], tensors[2], None, None, None + class UnfusedDotProductAttention(torch.nn.Module): """Parallel attention w/o QKV and Proj Gemms @@ -149,13 +216,19 @@ def __init__( attention_dropout: float = 0.0, attention_dropout_ctx: Optional[Callable] = nullcontext, layer_number: Optional[int] = None, + softmax_type: str = "vanilla", + return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.softmax_scale = softmax_scale self.attention_type = attention_type self.attention_dropout_ctx = attention_dropout_ctx self.layer_number = layer_number + self.softmax_type = softmax_type + self.return_max_logit = return_max_logit def mask_func(x, y): return ( @@ -164,6 +237,7 @@ def mask_func(x, y): else attention_mask_func(x, y) ) + self.mask_func = mask_func self.scale_mask_softmax = FusedScaleMaskSoftmax(mask_func) # Dropout. 
Note that for a single iteration, this layer will generate @@ -185,6 +259,8 @@ def forward( qkv_layout: str = "sbh3d", cu_seqlens_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument cu_seqlens_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + max_seqlen_q: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + max_seqlen_kv: Optional[torch.Tensor] = None, # pylint: disable=unused-argument attn_mask_type: str = "causal", attention_mask: Optional[Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = None, window_size: Optional[Tuple[int, int]] = None, @@ -192,6 +268,11 @@ def forward( core_attention_bias: Optional[torch.Tensor] = None, alibi_slopes: Optional[torch.Tensor] = None, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + quantizers=None, + fp8_output: bool = False, ) -> torch.Tensor: """Unfused attention fprop""" assert ( @@ -203,6 +284,9 @@ def forward( if inference_params is not None and inference_params.is_paged: key_layer, value_layer = inference_params.convert_paged_to_nonpaged(self.layer_number) + # convert to sbhd + # training: bshd, thd + # inference: bshd, sbhd_2bshd, thd_2bshd if qkv_format == "bshd": # convert to sbhd and use sbhd implementation for now query_layer, key_layer, value_layer = [ @@ -211,9 +295,8 @@ def forward( if qkv_format == "sbhd_2bshd": key_layer, value_layer = [x.transpose(0, 1) for x in [key_layer, value_layer]] - total_tokens, batch_size = None, None if qkv_format == "thd_2bshd": - total_tokens, batch_size = query_layer.shape[0], key_layer.shape[0] + batch_size = key_layer.shape[0] query_layer = tex.convert_thd_to_bshd( query_layer, cu_seqlens_q, @@ -223,6 +306,26 @@ def forward( query_layer, key_layer, value_layer = [ x.transpose(0, 1) for x in [query_layer, key_layer, value_layer] ] + if qkv_format == "thd": + assert cu_seqlens_q is not None and cu_seqlens_kv is not 
None + assert max_seqlen_q is not None and max_seqlen_kv is not None + query_layer = ConvertTHDtoBSHD.apply( + query_layer, + cu_seqlens_q, + max_seqlen_q, + ) + key_layer, value_layer = [ + ConvertTHDtoBSHD.apply( + x, + cu_seqlens_kv, + max_seqlen_kv, + ) + for x in [key_layer, value_layer] + ] + query_layer, key_layer, value_layer = [ + x.transpose(0, 1).contiguous() for x in [query_layer, key_layer, value_layer] + ] + batch_size, max_seqlen_q, max_seqlen_kv = ( query_layer.shape[1], query_layer.shape[0], @@ -289,6 +392,35 @@ def forward( if apply_qk_layer_scaling: scale /= self.layer_number + if fp8: + # get quantizers from DPA; all Nones if not fp8 + QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( + dpa_utils.get_attention_quantizers(fp8, quantizers) + ) + # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if fp8_recipe.float8_current_scaling(): + S_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=S_quantizer.dtype, device="cuda" + ) + dP_quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=dP_quantizer.dtype, device="cuda" + ) + + if "2" in qkv_layout or "3" in qkv_layout: + qkv_format, *_ = dpa_utils.get_qkv_format(qkv_layout) + qkv_layout = "_".join([qkv_format] * 3) + # quantize and dequantize QKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, QKV_quantizer, "QKV_quantizer", qkv_layout + ) + # quantize and dequantize dQKV to emulate FP8 + query_layer, key_layer, value_layer = FP8EmulationFunc.apply( + query_layer, key_layer, value_layer, dQKV_quantizer, "dQKV_quantizer", qkv_layout + ) + # Raw attention scores. 
[b * np, sq, sk] if core_attention_bias_type == "no_bias": matmul_result = torch.baddbmm( @@ -333,7 +465,36 @@ def forward( dtype=query_layer.dtype ) - # attention scores and attention mask [b, np, sq, sk] + if fp8: + # quantize and dequantize dP to emulate FP8 + matmul_result, *_ = FP8EmulationFunc.apply( + matmul_result, None, None, dP_quantizer, "dP_quantizer", None + ) + + # max attention score + max_logit = None + if self.return_max_logit: + # matmul_result [b, np, sq, dk], max_logit [np] + max_logit = matmul_result + if attn_mask_type != "no_mask": + max_logit = self.mask_func(matmul_result, attention_mask) + max_logit = torch.amax(max_logit, dim=(0, 2, 3)) + + # add attention sink to the last column: [b, np, sq, sk+1] + if self.softmax_type != "vanilla": + matmul_result = torch.cat( + [ + matmul_result, + softmax_offset.to(dtype=matmul_result.dtype).expand( + matmul_result.size(0), -1, matmul_result.size(2), -1 + ), + ], + dim=-1, + ) + attention_mask = F.pad(attention_mask, (0, 1), mode="constant", value=False) + attn_mask_type = "arbitrary" + + # attention scores and attention mask softmax_scale = self.layer_number if apply_qk_layer_scaling else None attention_probs = self.scale_mask_softmax( matmul_result, attention_mask, attn_mask_type, softmax_scale @@ -344,6 +505,10 @@ def forward( if "padding" in attn_mask_type: attention_probs = attention_probs.masked_fill(attention_mask, 0) + # remove attention sink: [b, np, sq, sk] + if self.softmax_type != "vanilla": + attention_probs = attention_probs[..., :-1] + # This is actually dropping out entire tokens to attend to, which might # seem a bit unusual, but is taken from the original Transformer paper. 
with self.attention_dropout_ctx(): @@ -364,6 +529,12 @@ def forward( # change view [b * np, sq, sk] attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) + if fp8: + # quantize and dequantize S to emulate FP8 + attention_probs, *_ = FP8EmulationFunc.apply( + attention_probs, None, None, S_quantizer, "S_quantizer", None + ) + # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) @@ -389,14 +560,30 @@ def forward( context_layer = context_layer.permute(0, 2, 1, 3).contiguous() # [b, sq, np, hn] --> [tq, np, hn] - context_layer = tex.convert_bshd_to_thd( + context_layer = ConvertBSHDtoTHD.apply( context_layer, cu_seqlens_q, - total_tokens, ) # [tq, np, hn] --> [tq, hp] - context_layer = context_layer.view(total_tokens, -1) + context_layer = context_layer.view(context_layer.shape[0], -1) + + if fp8: + # quantize and dequantize O to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, O_quantizer, "O_quantizer", None + ) + # quantize and dequantize dO to emulate FP8 + context_layer, *_ = FP8EmulationFunc.apply( + context_layer, None, None, dO_quantizer, "dO_quantizer", None + ) + + # quantize O + if fp8_output: + context_layer = O_quantizer(context_layer) + + if self.return_max_logit: + return context_layer, max_logit return context_layer @@ -496,6 +683,7 @@ def forward( quantizers=None, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), + fp8_output: bool = False, ) -> torch.Tensor: """flash-attn fprop""" @@ -564,6 +752,9 @@ def forward( x.contiguous() for x in (query_layer._data, key_layer._data, value_layer._data) ] + if is_cpu_offload_enabled(): + start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True) + # get batch_size, max_seqlen and cu_seqlens batch_size, context_len = None, None if inference_params is None: @@ -701,14 +892,10 @@ def forward( 
quantizers=quantizers, pad_between_seqs=False, use_flash_attn_3=use_flash_attn_3, + fp8_output=fp8_output, ) else: - from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadEnabled, - mark_activation_offload, - ) - - if CPUOffloadEnabled: + if is_cpu_offload_enabled(): mark_activation_offload( query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv ) @@ -800,8 +987,6 @@ def convert_to_torch_float8(tensor, dtype): ) return out - # "fp8_mha" decides outputs in fp8, while inputs are inferred from - # the real dtype assert isinstance(key_layer, query_layer.__class__) and isinstance( value_layer, query_layer.__class__ ), "q, k, and v must have the same type." @@ -848,7 +1033,7 @@ def convert_to_torch_float8(tensor, dtype): if fp8: output = output.to(dtype=torch_orig_dtype) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: O_quantizer = quantizers["scaling_fwd"][META_O] output = O_quantizer(output) @@ -876,7 +1061,7 @@ def convert_to_torch_float8(tensor, dtype): if q_format == "sbhd": # (bs)hd -> bs(hd) -> sb(hd) - if fp8 and fp8_meta["recipe"].fp8_mha: + if fp8 and fp8_output: output_data = ( output._data.reshape(batch_size, max_seqlen_q // cp_size, -1) .transpose(0, 1) @@ -900,7 +1085,7 @@ def convert_to_torch_float8(tensor, dtype): class FusedAttnFunc(torch.autograd.Function): - """Function for FusedAttention with separate Q, K, V tensors""" + """FusedAttention forward and backward implementation""" @staticmethod def forward( @@ -924,6 +1109,7 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, fused_attention_backend, @@ -932,55 +1118,77 @@ def forward( fp8_meta, quantizers, deterministic, + softmax_offset, + fp8_output, + layer_number, + return_max_logit, ): # pylint: disable=missing-function-docstring - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = fp8_meta["recipe"].fp8_mha if "recipe" in fp8_meta else False - 
# FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e4m3fn - fake_dtype = q.dtype + # add NVTX range + nvtx_label = "transformer_engine.FusedAttnFunc.forward" + nvtx_range_push(f"{nvtx_label}") + + if is_cpu_offload_enabled(): + start_offload(q, k, v, offload_base_tensor=True) + + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + + # input types are inferred from the real data while output types are controlled by fp8_output + # fp8_output should be set upstream as (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_mha) + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + # whether fwd kernel in FP8: fp8 = (DPA.fp8 and DPA.fp8_meta["recipe"].fp8_dpa) + # whether bwd kernel in FP8: + is_bwd_fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + + # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + # get nominal data type for out + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + out_nominal_dtype = q.dtype + + max_logit = None if fp8: fused_attention_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." 
- is_input_fp8 = isinstance(q, Float8Tensor) - q_fp8, k_fp8, v_fp8 = None, None, None + # q, k, v: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # q_fp8, k_fp8, v_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 if is_input_fp8: q_fp8, k_fp8, v_fp8 = q, k, v else: - # 1: qkv packed, 2: kv packed, 3: qkv separate - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - match qkv_group: - case 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_fp8 = QKV_quantizer(qkv) - q_fp8, k_fp8, v_fp8 = SplitAlongDim.apply(qkv_fp8, dim, [1, 1, 1], True) - case 2: - q_fp8 = QKV_quantizer(q) - dim = qkv_layout.split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_fp8 = QKV_quantizer(kv_c) - k_fp8, v_fp8 = SplitAlongDim.apply(kv_fp8, dim, [1, 1], True) - case 3: - q_fp8 = QKV_quantizer(q) - k_fp8 = QKV_quantizer(k) - v_fp8 = QKV_quantizer(v) - case _: - raise "Invalid qkv_layout " + qkv_layout - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - out_fp8, aux_ctx_tensors = fused_attn_fwd( + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors, *_ = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -989,7 +1197,7 @@ def forward( q_fp8, k_fp8, v_fp8, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1004,45 +1212,60 @@ def forward( qkv_layout, 
attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, + cuda_graph=is_graph_capturing(), ) - if is_output_fp8: - out_ret = out_fp8 + + # out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_fp8 = out_ + out = out_ + + if isinstance(out_, Float8Tensor): + if not is_output_fp8 or not is_bwd_fp8: + out = out_.dequantize().view(out_.shape) else: - out_ret = out_fp8.dequantize().view(out_fp8.shape) - # is_output_fp8 = False: out_save.dtype = torch.float16 or torch.bfloat16 - # is_output_fp8 = True: out_save.dtype = torch.float8_e4m3fn - out_save = out_ret + if is_output_fp8 or ( + is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_fp8 = O_quantizer(out_) + + # print quantizers + print_quantizers( + "FusedAttnFunc.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # return appropriate tensors + out_ret = out_fp8 if is_output_fp8 else out - if not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - # 1: qkv packed, 2: kv packed, 3: qkv separate + # save appropriate tensors + fp8_tensors = (None, None, None, None) + qkvo_tensors = (None, None, None, None) + if is_bwd_fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_fp8, k_fp8, v_fp8, None) + qkvo_tensors = (None, None, None, out) + else: + fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + else: if is_input_fp8: - qkv_group = len(qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = qkv_layout.find("3") - qkv = combine_tensors([q, k, v], dim) - qkv_c = qkv.view(-1, qkv.shape[-3] * qkv.shape[-2] * qkv.shape[-1]) - qkv_no_fp8 = qkv_c.dequantize().view(qkv.shape) - q, k, v = SplitAlongDim.apply(qkv_no_fp8, dim, [1, 1, 1], True) - if qkv_group == 2: - q = q.dequantize() - dim = 
qkv_layout.replace("paged_kv_", "").split("_")[1].find("2") - kv = combine_tensors([k, v], dim) - kv_c = kv.view(-1, kv.shape[-3] * kv.shape[-2] * kv.shape[-1]) - kv_no_fp8 = kv.dequantize() - k, v = SplitAlongDim.apply(kv_no_fp8, dim, [1, 1], True) - if qkv_group == 3: - q = q.dequantize() - k = k.dequantize() - v = v.dequantize() - if is_output_fp8: - out_save = out_fp8.dequantize() - - fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) + q, k, v = combine_and_dequantize(qkv_layout, q_fp8, k_fp8, v_fp8) + qkvo_tensors = (q, k, v, out) else: - # q, k, v, out_ret: torch.float16 or torch.bfloat16 - out_ret, aux_ctx_tensors = fused_attn_fwd( + # q, k, v, out_: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -1051,7 +1274,7 @@ def forward( q, k, v, - fake_dtype, + out_nominal_dtype, fused_attention_backend, attn_bias, cu_seqlens_q_padded, @@ -1066,32 +1289,38 @@ def forward( qkv_layout, attn_bias_type, attn_mask_type, + softmax_type, window_size, rng_gen, + softmax_offset, + return_max_logit, + is_graph_capturing(), ) - out_save = out_ret + out = out_ + out_ret = out_ fp8_tensors = (None, None, None, None) + qkvo_tensors = (q, k, v, out) - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + nvtx_range_pop(f"{nvtx_label}") - from transformer_engine.pytorch.cpu_offload import ( - CPUOffloadEnabled, - mark_activation_offload, - ) + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = is_bwd_fp8 + # assume fwd and bwd always use the same high precision, i.e. 
torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + ctx.nominal_dtype = out_nominal_dtype - if CPUOffloadEnabled: + if is_cpu_offload_enabled() and NVTE_CPU_OFFLOAD_V1: if ctx.fp8: tensor_list = fp8_tensors else: - tensor_list = [q, k, v, out_save] + tensor_list = [q, k, v, out] - qkv_layout = "sbhd_sbhd_sbhd" mark_activation_offload(*tensor_list) mark_activation_offload(*aux_ctx_tensors) ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 - qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None) + tensors_to_save, tensor_objects = prepare_for_saving( *fp8_tensors, *qkvo_tensors, @@ -1105,11 +1334,14 @@ def forward( ctx.tensor_objects = tensor_objects ctx.fp8_meta = fp8_meta + ctx.layer_number = layer_number + ctx.QKV_quantizer = QKV_quantizer + ctx.O_quantizer = O_quantizer ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.S_quantizer = S_quantizer - if ctx.fp8: + if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer): ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() @@ -1118,9 +1350,33 @@ def forward( ctx.attn_scale = attn_scale ctx.dropout_p = dropout_p ctx.fast_zero_fill = fast_zero_fill - ctx.qkv_layout = qkv_layout + + if NVTE_CPU_OFFLOAD_V1: + # If interleaved tensor is offloaded, reloaded tensor will be + # non-interleaved, so we need to modify the QKV layout + # for backward + if is_current_layer_offloaded() and is_cpu_offload_enabled(): + reload_layout = "" + split_list = qkv_layout.split("_") + for split in split_list: + temp_layout = "" + rep_count = 1 + for s in split: + if s.isalpha(): + temp_layout = temp_layout + s + else: + rep_count = int(s) + for _ in range(rep_count): + reload_layout = reload_layout + temp_layout + "_" + ctx.qkv_layout = reload_layout[:-1] + else: + ctx.qkv_layout = qkv_layout + else: + ctx.qkv_layout = qkv_layout + ctx.attn_bias_type = 
attn_bias_type ctx.attn_mask_type = attn_mask_type + ctx.softmax_type = softmax_type ctx.window_size = window_size ctx.fused_attention_backend = ( fused_attention_backend if (IS_HIP_EXTENSION or ctx.fp8) else FusedAttnBackend["F16_arbitrary_seqlen"] @@ -1128,22 +1384,22 @@ def forward( ctx.use_FAv2_bwd = use_FAv2_bwd ctx.deterministic = deterministic + if return_max_logit: + return out_ret, *max_logit return out_ret @staticmethod - def backward(ctx, d_out): + def backward(ctx, d_out, *_args): # pylint: disable=missing-function-docstring - if ctx.is_output_fp8: - assert isinstance( - d_out, Float8Tensor - ), "Gradient of the DPA output must be in Float8Tensor type for FP8 MHA." - - # FP16/BF16 attn: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = False: fake_dtype = torch.float16 or torch.bfloat16 - # FP8 attn, is_output_fp8 = True: fake_dtype = torch.float8_e5m2 - fake_dtype = d_out.dtype - d_out = d_out.contiguous() + # d_out is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(d_out, QuantizedTensorStorage): + d_out = ctx.dO_quantizer(d_out) + if not ctx.use_FAv2_bwd: + d_out._data = d_out._data.contiguous() + elif not ctx.use_FAv2_bwd: + d_out = d_out.contiguous() ( q_fp8, k_fp8, @@ -1197,16 +1453,55 @@ def backward(ctx, d_out): dk = dk[..., : d_out.shape[-1]] dv = dv[..., : d_out.shape[-1]] else: - with torch.cuda.nvtx.range("_FusedAttn"): + with torch.cuda.nvtx.range("FusedAttnFunc.backward"): + # get nominal data type of dq, dk, dv + # FP16/BF16 attention: torch.float16 or torch.bfloat16 + # FP8 attention: torch.float16 or torch.bfloat16 + dqkv_nominal_dtype = ctx.nominal_dtype + if ctx.fp8: + # d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 if ctx.is_output_fp8: d_out_fp8 = d_out else: d_out_fp8 = 
ctx.dO_quantizer(d_out) - dqkv_dtype = TE_DType[d_out_fp8._data.dtype] - # q_fp8, k_fp8, v_fp8, out_fp8: torch.float8_e4m3fn - # d_out_fp8, dq_fp8, dk_fp8, dv_fp8: torch.float8_e5m2 - dq_fp8, dk_fp8, dv_fp8, *rest = fused_attn_bwd( + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # get tex.DType for dq, dk, dv data + dqkv_te_dtype = d_out_fp8._fp8_dtype + + # q_fp8, k_fp8, v_fp8, out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16, + # fp8_dtype = tex.DType.kFloat8E4M3 + # d_out_fp8: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # out_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E4M3 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + # + # dq_, dk_, dv_: + # DelayedScaling: Float8Tensor; dtype = torch.float16 or torch.bfloat16 + # fp8_dtype = tex.DType.kFloat8E5M2 + # Float8CurrentScaling: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + out_ = ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ) + dq_, dk_, dv_, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -1214,10 +1509,10 @@ def backward(ctx, d_out): q_fp8, k_fp8, v_fp8, - out_fp8, + out_, d_out_fp8, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1231,44 +1526,46 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, + is_graph_capturing(), ) - # is_input_fp8 = False: dq, dk, dv: torch.float16 or torch.bfloat16 - # is_input_fp8 = True: dq, dk, dv: torch.float8_e5m2 - if not ctx.is_input_fp8: - qkv_group = 
len(ctx.qkv_layout.replace("paged_kv_", "").split("_")) - if qkv_group == 1: - dim = ctx.qkv_layout.find("3") - dqkv_fp8_data = combine_tensors( - [dq_fp8._data, dk_fp8._data, dv_fp8._data], dim - ) - dqkv_fp8 = dq_fp8.make_like( - tensor=dq_fp8, data=dqkv_fp8_data, shape=dqkv_fp8_data.shape - ) - dqkv = dqkv_fp8.dequantize() - dq, dk, dv = SplitAlongDim.apply(dqkv, dim, [1, 1, 1], True) - if qkv_group == 2: - dq = dq_fp8.dequantize() - dim = ctx.qkv_layout.split("_")[1].find("2") - dkv_fp8 = combine_tensors([dk_fp8, dv_fp8], dim) - dkv_c_fp8 = dkv_fp8.view( - -1, dkv_fp8.shape[-3] * dkv_fp8.shape[-2] * dkv_fp8.shape[-1] - ) - dkv = dkv_c_fp8.dequantize() - dk, dv = SplitAlongDim.apply(dkv, dim, [1, 1], True) - if qkv_group == 3: - dq = dq_fp8.dequantize() - dk = dk_fp8.dequantize() - dv = dv_fp8.dequantize() - else: - dq, dk, dv = dq_fp8, dk_fp8, dv_fp8 + # dq, dk, dv: torch.Tensor; dtype = torch.float16 or torch.bfloat16 + dq, dk, dv = dq_, dk_, dv_ + is_float8tensor = isinstance(dq_, Float8Tensor) + if is_float8tensor and not ctx.is_input_fp8: + # return in F16 + dq, dk, dv = combine_and_dequantize( + ctx.qkv_layout, + dq_, + dk_, + dv_, + src_nominal_dtype=dq_.dtype, + ) + if not is_float8tensor and ctx.is_input_fp8: + # return in FP8 + dq, dk, dv = combine_and_quantize( + ctx.qkv_layout, dq_, dk_, dv_, ctx.dQKV_quantizer + ) + + # print quantizers + print_quantizers( + "FusedAttnFunc.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) else: - if isinstance(d_out, QuantizedTensor): - d_out = d_out.dequantize() - dqkv_dtype = TE_DType[d_out.dtype] - # q, k, v, out, d_out, dq, dk, dv: torch.float16 or torch.bfloat16 + if isinstance(d_out, QuantizedTensorStorage): + d_out = d_out.dequantize(dtype=ctx.nominal_dtype) + dqkv_te_dtype = TE_DType[d_out.dtype] + # q, k, v, out, d_out, dq, dk, dv: torch.Tensor; torch.float16 or torch.bfloat16 dq, dk, dv, 
*rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, @@ -1279,8 +1576,8 @@ def backward(ctx, d_out): v, out, d_out, - fake_dtype, - dqkv_dtype, + dqkv_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, ctx.fused_attention_backend, cu_seqlens_q_padded, @@ -1294,42 +1591,18 @@ def backward(ctx, d_out): ctx.qkv_layout, ctx.attn_bias_type, ctx.attn_mask_type, + ctx.softmax_type, ctx.window_size, ctx.deterministic, + is_graph_capturing(), ) - # if no_bias or alibi, return dqkv - if ctx.attn_bias_type in ["no_bias", "alibi"]: - return ( - None, - None, - None, - None, - None, - None, - None, - None, - None, - dq, - dk, - dv, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - ) - # else, return (dqkv, dbias) + d_bias = None + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + d_softmax_offset = None + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] return ( None, None, @@ -1343,7 +1616,10 @@ def backward(ctx, d_out): dq, dk, dv, - rest[0], + d_bias, + None, + None, + None, None, None, None, @@ -1356,6 +1632,8 @@ def backward(ctx, d_out): None, None, None, + d_softmax_offset, + None, None, None, ) @@ -1397,7 +1675,11 @@ def __init__( attention_type: str = "self", layer_number: Optional[int] = None, deterministic: bool = False, + softmax_type: str = "vanilla", + return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." 
super().__init__() self.softmax_scale = softmax_scale @@ -1409,6 +1691,8 @@ def __init__( ) == "1" and get_device_compute_capability() == (9, 0) self.layer_number = 1 if layer_number is None else layer_number self.deterministic = deterministic + self.softmax_type = softmax_type + self.return_max_logit = return_max_logit def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument """ @@ -1460,6 +1744,8 @@ def forward( quantizers=None, pad_between_seqs: bool = False, inference_params: Optional[InferenceParams] = None, + softmax_offset: torch.Tensor = None, + fp8_output: bool = False, ) -> torch.Tensor: """fused attention fprop""" assert ( @@ -1560,15 +1846,27 @@ def forward( ) if fp8: + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] assert fused_attention_backend == tex.NVTE_Fused_Attn_Backend.NVTE_FP8, ( f"cuDNN attention sub-backend {int(tex.NVTE_Fused_Attn_Backend.NVTE_FP8)}" " is required for FP8 attention!" ) assert fp8_meta is not None, "FP8 metadata fp8_meta is required for FP8 attention!" - assert not context_parallel or fp8_meta["recipe"].reduce_amax, ( - "Amax reduction across TP+CP group is necessary when using context parallelism with" - " FP8!" - ) + if fp8_recipe.delayed(): + assert not context_parallel or fp8_recipe.reduce_amax, ( + "Amax reduction across TP+CP group is necessary when using context parallelism" + " with FP8!" 
+ ) + if fp8_recipe.float8_current_scaling() and context_parallel: + all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) + for q in all_quantizers: + if isinstance(q, Float8CurrentScalingQuantizer): + q.with_amax_reduction = True + q.amax_reduction_group = ( + cp_group[0] if cp_comm_type == "a2a+p2p" else cp_group + ) if context_parallel: assert ( @@ -1610,6 +1908,11 @@ def forward( fp8_meta=fp8_meta, quantizers=quantizers, pad_between_seqs=pad_between_seqs, + softmax_type=self.softmax_type, + softmax_offset=softmax_offset, + fp8_output=fp8_output, + layer_number=self.layer_number, + return_max_logit=self.return_max_logit, ) else: with self.attention_dropout_ctx(): @@ -1633,6 +1936,7 @@ def forward( qkv_layout, core_attention_bias_type, attn_mask_type, + self.softmax_type, window_size, None, # rng_gen fused_attention_backend, @@ -1641,7 +1945,14 @@ def forward( fp8_meta, quantizers, self.deterministic, + softmax_offset, + fp8_output, + self.layer_number, + self.return_max_logit, ) + if self.return_max_logit: + # ...hd -> ...(hd) + return output[0].view(*output[0].shape[:-2], -1), output[1] # ...hd -> ...(hd) return output.view(*output.shape[:-2], -1) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 7eedd688f..096eca809 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -12,7 +12,6 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.utils import ( - combine_tensors, get_cudnn_version, nvtx_range_pop, nvtx_range_push, @@ -23,8 +22,11 @@ fused_attn_bwd, FusedAttnBackend, ) -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage from transformer_engine.pytorch.jit import jit_fuser +from transformer_engine.pytorch.graph import is_graph_capturing from transformer_engine.pytorch.constants import ( dist_group_type, TE_DType, @@ -35,7 +37,8 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( + +from transformer_engine.pytorch.quantized_tensor import ( prepare_for_saving, restore_from_saved, ) @@ -44,11 +47,18 @@ import transformer_engine.pytorch.attention.dot_product_attention.utils as dpa_utils from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils as fa_utils, + combine_and_quantize, + combine_and_dequantize, + print_quantizers, ) _cu_seqlens_info_with_cp_cache = {} _seq_chunk_ids_cache_for_reordering_before_attn = {} _seq_chunk_ids_cache_for_reordering_after_attn = {} +_softmax_offset_chunk_ids_cache = {} + +# Float8CurrentScaling: fused_attn_bwd takes O in FP8 by default, this flag allows it in F16 +_dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" def flash_attn_p2p_communicate( @@ -228,11 +238,11 @@ def get_seq_chunk_ids_for_reordering_after_attn(cp_size, device): @jit_fuser def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication before attention compute.""" - # [cp, b, s, np//cp, hn] -> [b, cp, s, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp, s, b, 
np//cp, hn] + # [cp, b, s, h//cp, d] -> [b, cp, s, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp, s, b, h//cp, d] x = x.movedim(0, seq_dim).contiguous() - # [b, cp, s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp, s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp, s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp, s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 2) :]) # reorder the sequence chunks x = torch.index_select(x, dim=seq_dim, index=chunk_ids_for_a2a) @@ -242,13 +252,13 @@ def reorder_seq_chunks_for_a2a_before_attn(x, chunk_ids_for_a2a, seq_dim, cp_siz @jit_fuser def reorder_seq_chunks_for_a2a_after_attn(x, chunk_ids_for_a2a, seq_dim, cp_size): """Reorder sequence chunk for A2A communication after attention compute.""" - # [b, cp*2, s//2, np//cp, hn] -> [cp*2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [cp*2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.movedim(seq_dim, 0).contiguous() # reorder the sequence chunks x = torch.index_select(x, dim=0, index=chunk_ids_for_a2a) - # [cp*2, b, s//2, np//cp, hn] -> [cp, 2, b, s//2, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp, 2, s//2, b, np//cp, hn] + # [cp*2, b, s//2, h//cp, d] -> [cp, 2, b, s//2, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp, 2, s//2, b, h//cp, d] x = x.view(cp_size, 2, *x.shape[1:]) return x @@ -280,16 +290,16 @@ def flash_attn_a2a_communicate( x = reorder_seq_chunks_for_a2a_before_attn( x, chunk_ids_for_a2a, seq_dim, cp_size ) - # [b, cp*2, s//2, np//cp, hn] -> [b, cp*s, np//cp, hn] - # or [cp*2, s//2, b, np//cp, hn] -> [cp*s, b, np//cp, hn] + # [b, cp*2, s//2, h//cp, d] -> [b, cp*s, h//cp, d] + # or [cp*2, s//2, b, h//cp, d] -> [cp*s, b, h//cp, d] a2a_outputs[i - 2] = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2) :]) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, s, 
np, hn] -> [b, s, cp, np//cp, hn] - # or [s, b, np, hn] -> [s, b, cp, np//cp, hn] + # [b, s, h, d] -> [b, s, cp, h//cp, d] + # or [s, b, h, d] -> [s, b, cp, h//cp, d] x = x.view(*x.shape[:-2], cp_size, x.shape[-2] // cp_size, x.shape[-1]) - # [b, s, cp, np//cp, hn] -> [cp, b, s, np//cp, hn] - # or [s, b, cp, np//cp, hn] -> [cp, s, b, np//cp, hn] + # [b, s, cp, h//cp, d] -> [cp, b, s, h//cp, d] + # or [s, b, cp, h//cp, d] -> [cp, s, b, h//cp, d] a2a_inputs[i] = x.movedim(-3, 0).contiguous() else: for i in range(len(a2a_inputs) + 2): @@ -300,8 +310,8 @@ def flash_attn_a2a_communicate( ) if i < len(a2a_inputs): x = a2a_inputs[i] - # [b, cp*s, np//cp, hn] -> [b, cp*2, s//2, np//cp, hn] - # or [cp*s, b, np//cp, hn] -> [cp*2, s//2, b, np//cp, hn] + # [b, cp*s, h//cp, d] -> [b, cp*2, s//2, h//cp, d] + # or [cp*s, b, h//cp, d] -> [cp*2, s//2, b, h//cp, d] x = x.view(*x.shape[:seq_dim], cp_size * 2, -1, *x.shape[(seq_dim + 1) :]) # reorder the sequence chunks a2a_inputs[i] = reorder_seq_chunks_for_a2a_after_attn( @@ -311,16 +321,65 @@ def flash_attn_a2a_communicate( with torch.cuda.stream(cp_stream): a2a_reqs[i - 2].wait() x = a2a_outputs[i - 2] - # [cp, 2, b, s//2, np//cp, hn] -> [b, 2, s//2, cp, np//cp, hn] - # or [cp, 2, s//2, b, np//cp, hn] -> [2, s//2, b, cp, np//cp, hn] + # [cp, 2, b, s//2, h//cp, d] -> [b, 2, s//2, cp, h//cp, d] + # or [cp, 2, s//2, b, h//cp, d] -> [2, s//2, b, cp, h//cp, d] x = x.movedim(0, -3).movedim(0, seq_dim).contiguous() - # [b, 2, s//2, cp, np//cp, hn] -> [b*s, np, hn] - # or [2, s//2, b, cp, np//cp, hn] -> [s*b, np, hn] + # [b, 2, s//2, cp, h//cp, d] -> [b*s, h, d] + # or [2, s//2, b, cp, h//cp, d] -> [s*b, h, d] a2a_outputs[i - 2] = x.view(-1, x.shape[-3] * x.shape[-2], x.shape[-1]) torch.cuda.current_stream().wait_stream(cp_stream) return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs +def flash_attn_a2a_communicate_softmax_offset( + tensor: torch.Tensor, + h_dim: int, + cp_size: int, + cp_group: dist_group_type, + cp_stream: 
torch.cuda.Stream, + before_attn: bool, +) -> Union[torch.Tensor, List[torch.Tensor]]: + """Split/AllGather communication for softmax offset.""" + if tensor is None: + return None + + global _softmax_offset_chunk_ids_cache + device = tensor.device + if (cp_size, device) not in _softmax_offset_chunk_ids_cache: + chunk_ids = torch.arange(cp_size, dtype=torch.int32, device=device) + _softmax_offset_chunk_ids_cache[(cp_size, device)] = chunk_ids + else: + chunk_ids = _softmax_offset_chunk_ids_cache[(cp_size, device)] + + if before_attn: + # softmax_offset: split round-robin to CP ranks + # [1, h, 1, 1] -> [1, cp, h//cp, 1, 1] + shape = tensor.shape + tensor = tensor.view( + *shape[:h_dim], cp_size, shape[h_dim] // cp_size, *shape[(h_dim + 1) :] + ) + rank = get_distributed_rank(cp_group) + output = torch.index_select(tensor, dim=h_dim, index=chunk_ids[rank]) + output = output.view(*shape[:h_dim], -1, *shape[(h_dim + 1) :]) + else: + # d_softmax_offset: all-gather from all ranks to all ranks + # [1, h//cp, 1, 1] -> [1, h, 1, 1] + inp = tensor.view(-1) + output = torch.empty(cp_size * inp.shape[0], dtype=tensor.dtype, device=device) + with torch.cuda.stream(cp_stream): + torch.distributed.all_gather_into_tensor( + output, + inp, + group=cp_group, + async_op=False, + ) + torch.cuda.current_stream().wait_stream(cp_stream) + output = output.view( + *tensor.shape[:h_dim], cp_size * tensor.shape[h_dim], *tensor.shape[h_dim + 1 :] + ) + return output + + def _get_cu_seqlens_info_with_cp( batch_size: int, max_seqlen: int, @@ -420,6 +479,591 @@ def get_fa_args( ] +def cp_p2p_fwd_prepare_qkv( + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + step, + cp_size, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P forward""" + cu_seqlens_q_per_step = None + cu_seqlens_kv_per_step = None + if section in ["diagonal", "all"]: + if 
pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + rank_ = rank if section == "diagonal" else (rank - step) % cp_size + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank_, True, True + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, k_part, v_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part]] + + elif section == "lower-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + False, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // cp_size + cu_seqlens_kv_per_step = cu_seqlens_kv // (cp_size * 2) + else: + cu_seqlens_q_per_step = cu_seqlens_q + cu_seqlens_kv_per_step = cu_seqlens_kv_half + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part = q_part.view(q_part.shape[0], -1, *q_part.shape[-2:]) + # [b, 2, sk//2, h, d] -> [b, sk//2, h, d] + k_part = k_part[:, 0, ...] + v_part = v_part[:, 0, ...] 
+ elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part = q_part.view(-1, *q_part.shape[-3:]) + # [2, sk//2, b, h, d] -> [sk//2, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + + elif section == "upper-triangle": + if pad_between_seqs: + cu_seqlens_q_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True + ) + cu_seqlens_kv_per_step = get_cu_seqlens_on_cp_rank( + cu_seqlens_kv, + cu_seqlens_kv_padded, + cp_size, + (rank - step) % cp_size, + True, + True, + ) + elif qkv_format == "thd": + cu_seqlens_q_per_step = cu_seqlens_q // (cp_size * 2) + cu_seqlens_kv_per_step = cu_seqlens_kv // cp_size + else: + cu_seqlens_q_per_step = cu_seqlens_q_half + cu_seqlens_kv_per_step = cu_seqlens_kv + + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part = q_part[:, 1, ...] 
+ # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part = q_part[1] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part = tex.thd_read_half_tensor(q_part, cu_seqlens_q_padded, 1) + + return q_part, k_part, v_part, cu_seqlens_q_per_step, cu_seqlens_kv_per_step + + +def cp_p2p_fwd_fused_attn( + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step, + O_quantizer_per_step, + rank, + step, + cp_size, + return_max_logit, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FusedAttention backend""" + attn_bias_inputs = None + max_seqlen_q_ = None + max_seqlen_kv_ = None + cu_seqlens_q_ = None + cu_seqlens_kv_ = None + attn_mask_type_ = None + cu_seqlens_q_padded_ = None + cu_seqlens_kv_padded_ = None + if section in ["diagonal", "all"]: + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias[..., idx, :], + attn_bias[..., (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = attn_mask_type + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + elif section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + 
attn_bias_inputs = attn_bias[..., idx, :].contiguous() + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = ( + cu_seqlens_kv_padded // 2 if cu_seqlens_kv_padded is not None else None + ) + elif section == "upper-triangle": + q_part = q_part.contiguous() + if attn_bias is not None: + idx = (rank - step) % cp_size + attn_bias_inputs = torch.cat( + ( + attn_bias_[..., 1, :, idx, :], + attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], + ), + dim=-1, + ).contiguous() + max_seqlen_q_ = max_seqlen_q // 2 + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + cu_seqlens_q_padded_ = cu_seqlens_q_padded // 2 if cu_seqlens_q_padded is not None else None + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step + fp8_meta_kwargs["o_quantizer"] = O_quantizer_per_step + + out_per_step, aux_ctx_tensors, *max_logit = fused_attn_fwd( + is_training, + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_, + cu_seqlens_kv_, + q_part, + k_part, + v_part, + fake_dtype=fwd_nominal_dtype, + fused_attention_backend=fused_attn_backend, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + attn_bias=attn_bias_inputs, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + **fp8_meta_kwargs, + return_max_logit=return_max_logit, + 
cuda_graph=is_graph_capturing(), + ) + + if fp8: + softmax_lse_per_step, _, rng_states = aux_ctx_tensors + else: + softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors + attn_bias = rest[0] if len(rest) > 0 else None + + if return_max_logit: + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, *max_logit + return out_per_step, softmax_lse_per_step, rng_states, attn_bias, None + + +def cp_p2p_fwd_flash_attn( + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + q_part, + k_part, + v_part, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + section, +): + """Per-tile forward call of CP P2P with FlashAttention backend""" + cu_seqlens_q_ = cu_seqlens_q_per_step + cu_seqlens_kv_ = cu_seqlens_kv_per_step + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + causal_ = False + if section in ["diagonal", "all"]: + causal_ = section == "diagonal" + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + if section in ["lower-triangle", "upper-triangle"]: + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_forward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_forward_kwargs["window_size_left"] = -1 + fa_forward_kwargs["window_size_right"] = -1 + + fa_forward_args_thd = get_fa_args( + True, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_, + cu_seqlens_kv=cu_seqlens_kv_, + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + ) + fa_outputs = flash_attn_fwd( + q_part, + k_part, + v_part, + *fa_forward_args_thd, + causal=causal_, + **fa_forward_kwargs, + ) + rng_states = None + if not fa_utils.v2_7_0_plus: + out_per_step = fa_outputs[4] + softmax_lse_per_step = fa_outputs[5] + if not use_flash_attn_3: + rng_states = fa_outputs[7] + else: + out_per_step = fa_outputs[0] + softmax_lse_per_step = fa_outputs[1] + if not use_flash_attn_3: 
+ rng_states = fa_outputs[3] + + return out_per_step, softmax_lse_per_step, rng_states + + +def cp_p2p_bwd_prepare_qkv( + q_part, + k_part, + v_part, + out_part, + dout_part, + qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + section, +): + """Prepare q, k, v and cu_seqlens for CP P2P backward""" + if section in ["diagonal", "all"]: + if qkv_format == "bshd": + # [b, 2, s//2, h, d] -> [b, s, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) + for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif qkv_format == "sbhd": + # [2, s//2, b, h, d] -> [s, b, h, d] + q_part, k_part, v_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, k_part, v_part, out_part, dout_part] + ] + elif section == "lower-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq, h, d] + q_part, out_part, dout_part = [ + x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q_part, out_part, dout_part] + ] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part = k_part[:, 0] + v_part = v_part[:, 0] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq, b, h, d] + q_part, out_part, dout_part = [ + x.view(-1, *x.shape[-3:]) for x in [q_part, out_part, dout_part] + ] + # [2, sk//2, b, h, d] -> [sk, b, h, d] + k_part = k_part[0] + v_part = v_part[0] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) + v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) + elif section == "upper-triangle": + if qkv_format == "bshd": + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + q_part, out_part, dout_part = q_part[:, 1], out_part[:, 1], dout_part[:, 1] + # [b, 2, sk//2, h, d] -> [b, sk, h, d] + k_part, v_part = [x.view(x.shape[0], -1, *x.shape[-2:]) for x in [k_part, v_part]] + elif qkv_format == "sbhd": + # [2, sq//2, b, h, d] -> [sq//2, b, h, d] + q_part, out_part, dout_part = q_part[1], out_part[1], dout_part[1] + # [2, sk//2, 
b, h, d] -> [sk, b, h, d] + k_part, v_part = [x.view(-1, *x.shape[-3:]) for x in [k_part, v_part]] + elif qkv_format == "thd": + # [t, h, d] -> [t/2, h, d] + q_part, out_part, dout_part = [ + tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) + for x in [q_part, out_part, dout_part] + ] + + return q_part, k_part, v_part, out_part, dout_part + + +def cp_p2p_bwd_fused_attn( + fp8, + fp8_recipe, + q_fp8, + kv_fp8, + out_fp8, + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + max_seqlen_q, + max_seqlen_kv, + step, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + deterministic, + fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + S_quantizer, + dP_quantizer_per_step, + dQKV_quantizer_per_step, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FusedAttention backend""" + if fp8: + aux_tensors = [ + softmax_lse, + softmax_lse, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] + + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + cu_seqlens_q_padded_ = cu_seqlens_q_padded + cu_seqlens_kv_padded_ = cu_seqlens_kv_padded + attn_mask_type_ = attn_mask_type + + if section == "lower-triangle": + k_part = k_part.contiguous() + v_part = v_part.contiguous() + max_seqlen_kv_ = max_seqlen_kv // 2 + cu_seqlens_kv_padded_ = None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + elif section == "upper-triangle": + q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] + if fp8: + aux_tensors = [ + softmax_lse_, + softmax_lse_, + rng_states[cp_size - step - 1], + ] + else: + aux_tensors = [softmax_lse_, rng_states[cp_size - 
step - 1]] + + max_seqlen_q_ = max_seqlen_q // 2 + cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 + attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" + + if attn_dbias is not None: + aux_tensors += [attn_biases[cp_size - step - 1]] + + fp8_meta_kwargs = {} + if fp8: + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip( + [q_fp8, kv_fp8, kv_fp8], + [q_part, k_part, v_part], + ) + ] + if not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16): + out_part = Float8Tensor.make_like(out_fp8, data=out_part, dtype=fwd_nominal_dtype) + dout_part = Float8Tensor.make_like(dout_fp8, data=dout_part, dtype=bwd_nominal_dtype) + fp8_meta_kwargs["s_quantizer"] = S_quantizer + fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step + fp8_meta_kwargs["dqkv_quantizer"] = dQKV_quantizer_per_step + + dq, dk, dv, dbias, *_ = fused_attn_bwd( + max_seqlen_q_, + max_seqlen_kv_, + cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv_per_step[cp_size - step - 1], + q_part, + k_part, + v_part, + out_part, + dout_part, + bwd_nominal_dtype, + bwd_output_te_dtype, + aux_tensors, + fused_attn_backend, + cu_seqlens_q_padded=cu_seqlens_q_padded_, + cu_seqlens_kv_padded=cu_seqlens_kv_padded_, + attn_scale=softmax_scale, + dropout=dropout_p, + qkv_layout=qkv_layout, + attn_mask_type=attn_mask_type_, + attn_bias_type=attn_bias_type, + deterministic=deterministic, + cuda_graph=is_graph_capturing(), + **fp8_meta_kwargs, + ) + + return dq, dk, dv, dbias + + +def cp_p2p_bwd_flash_attn( + use_flash_attn_3, + qkv_format, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + step, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + q_part, + k_part, + v_part, + out_part, + dout_part, + section, +): + """Per-tile backward call of CP P2P with FlashAttention backend""" + dq, dk, dv = 
[torch.empty_like(x) for x in [q_part, k_part, v_part]] + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, -1) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = -1 + if not use_flash_attn_3: + fa_backward_kwargs["rng_state"] = rng_states[cp_size - step - 1] + max_seqlen_q_ = max_seqlen_q + max_seqlen_kv_ = max_seqlen_kv + softmax_lse__ = softmax_lse + causal_ = False + if section == "diagonal": + if use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): + fa_backward_kwargs["window_size"] = (-1, 0) + elif fa_utils.v2_7_0_plus: + fa_backward_kwargs["window_size_left"] = -1 + fa_backward_kwargs["window_size_right"] = 0 + causal_ = True + elif section == "lower-triangle": + max_seqlen_kv_ = max_seqlen_kv // 2 + elif section == "upper-triangle": + max_seqlen_q_ = max_seqlen_q // 2 + softmax_lse__ = softmax_lse_ + + fa_backward_args_thd = get_fa_args( + False, + use_flash_attn_3, + qkv_format, + cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], + cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + max_seqlen_q=max_seqlen_q_, + max_seqlen_kv=max_seqlen_kv_, + dq=dq, + dk=dk, + dv=dv, + ) + flash_attn_bwd( + dout_part, + q_part, + k_part, + v_part, + out_part, + softmax_lse__, + *fa_backward_args_thd, + causal=causal_, + **fa_backward_kwargs, + ) + + return dq, dk, dv + + class AttnFuncWithCPAndKVP2P(torch.autograd.Function): """ Attention implementation with context parallelism. 
Exchange KV between CP ranks @@ -453,6 +1097,7 @@ def forward( attn_bias, deterministic, use_fused_attention, + return_max_logit, fp8, fp8_meta, cp_group, @@ -461,30 +1106,24 @@ def forward( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") - enable_mla = k.shape[-1] != v.shape[-1] - if softmax_scale is None: - softmax_scale = q.shape[-1] ** (-0.5) + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.forward" + nvtx_range_push(f"{nvtx_label}") + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} + cp_group_a2a = None + cp_size_a2a = 1 + rank_a2a = 0 if isinstance(cp_group, list): - assert ( - qkv_format != "thd" - ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" - assert attn_bias_type == "no_bias", ( - f"{attn_bias_type} bias type is not supported with hierarchical CP implementation" - " yet!" 
- ) cp_group_a2a = cp_group[0] cp_size_a2a = get_distributed_world_size(cp_group_a2a) rank_a2a = get_distributed_rank(cp_group_a2a) cp_group = cp_group[1] - else: - cp_group_a2a = None - cp_size_a2a = 1 - rank_a2a = 0 - cp_size = get_distributed_world_size(cp_group) rank = get_distributed_rank(cp_group) send_dst = cp_global_ranks[(rank + 1) % cp_size * cp_size_a2a + rank_a2a] @@ -494,18 +1133,19 @@ def forward( device_compute_capability < (10, 0) and cp_size == 2 ) + # set up attention args + enable_mla = k.shape[-1] != v.shape[-1] causal = "causal" in attn_mask_type - padding = "padding" in attn_mask_type + + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) batch_dim = None seq_dim = None cu_seqlens_q_half, cu_seqlens_kv_half = None, None + qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format if qkv_format in ["bshd", "sbhd"]: seq_dim = qkv_format.index("s") - if enable_mla: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format - else: - qkv_layout = qkv_format + "_" + qkv_format[:-2] + "2" + qkv_format[-2:] cu_seqlens_q_padded, cu_seqlens_kv_padded = None, None if use_fused_attention: batch_dim = qkv_format.index("b") @@ -516,7 +1156,6 @@ def forward( q.shape[batch_dim], max_seqlen_kv, cp_size, cu_seqlens_kv ) else: - qkv_layout = qkv_format + "_" + qkv_format + "_" + qkv_format cu_seqlens_q_padded = cu_seqlens_q_padded // cp_size cu_seqlens_kv_padded = cu_seqlens_kv_padded // cp_size @@ -526,79 +1165,116 @@ def forward( cu_seqlens_kv_per_step = [None for _ in range(cp_size)] fused_attn_backend = None - qkv_dtype = q.dtype amax_per_step = None S_quantizer_per_step = [None for _ in range(cp_size)] - O_CP_quantizer_per_step = [None for _ in range(cp_size)] - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False + O_quantizer_per_step = [None for _ in range(cp_size)] + max_logit_per_step = [None for _ in range(cp_size)] + max_logit = None + + assert 
isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." + fwd_nominal_dtype = q.dtype + is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] ( QKV_quantizer, O_quantizer, - O_CP_quantizer, S_quantizer, dQKV_quantizer, - dQKV_CP_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=True) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers) + + q_f16 = None + q_fp8, k_fp8, v_fp8 = (None, None, None) + # communicate for the 'a2a' part of 'a2a+p2p' + if cp_size_a2a > 1: + if fp8 and is_input_fp8: + QKV_quantizer = q._quantizer + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = (q._data, k._data, v._data) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) + q, k, v = flash_attn_a2a_communicate( + [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True + ) + if fp8 and is_input_fp8: + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q, k, v]) + ] + q, k, v = q_fp8, k_fp8, v_fp8 + # convert qkv to the right type if fp8: - if use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert use_fused_attention, "FP8 is only supported with Fused Attention!" + fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." 
- is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha - if is_input_fp8: - QKV_quantizer = q._quantizer - q, k, v = q._data, k._data, v._data - else: - q_f16, k_f16, v_f16 = q, k, v - if cp_size_a2a == 1 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q = QKV_quantizer(q_f16)._data - if int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - k, v = [QKV_quantizer(x)._data for x in [k_f16, v_f16]] - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - # partial result quantizer - for i in range(cp_size): - S_quantizer_per_step[i] = S_quantizer.copy() - S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - O_CP_quantizer_per_step[i] = O_CP_quantizer.copy() - O_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + if is_input_fp8: + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_fp8, k_fp8, v_fp8 = q, k, v + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] else: - assert False, "FP8 is only supported with Fused Attention!" 
+ # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q_fp8, k_fp8, v_fp8: Float8Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=torch.uint8 + q_f16 = q + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> before: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + + # amax_per_step[0]: amax_s x cp_size + # amax_per_step[1]: amax_o x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + S_quantizer_per_step[i] = S_quantizer.copy() + S_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + O_quantizer_per_step[i] = O_quantizer.copy() + O_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: + # q_f16: torch.Tensor, dtype=fwd_nominal_dtype + # q, k, v: torch.Tensor, dtype=fwd_nominal_dtype q_f16 = q if use_fused_attention: fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] + if return_max_logit: + max_logit_per_step = [ + torch.empty(q.shape[-2], dtype=q.dtype, device=q.device) for _ in range(cp_size) + ] - if cp_size_a2a > 1: - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size_a2a, q.device) - - q, k, v = flash_attn_a2a_communicate( - [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size_a2a, cp_group_a2a, cp_stream, True - ) - if not fp8: - q_f16 = q - elif not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16 = q - q = QKV_quantizer(q_f16)._data - + # split qkv to two halves and prepare for load balancing assert qkv_format == "thd" or ( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] 
% 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" if causal: if qkv_format == "bshd": - # [b, s, np, hn] -> [b, 2, s//2, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] q, k, v = [x.view(x.shape[0], 2, x.shape[1] // 2, *x.shape[2:]) for x in [q, k, v]] elif qkv_format == "sbhd": - # [s, b, np, hn] -> [2, s//2, b, np, hn] + # [s, b, h, d] -> [2, s//2, b, h, d] q, k, v = [x.view(2, x.shape[0] // 2, *x.shape[1:]) for x in [q, k, v]] + attn_bias_ = None if attn_bias is not None: assert len(attn_bias.shape) == 4, ( "Only support bias shape of [b, h, sq, sk] for forward, " @@ -607,7 +1283,7 @@ def forward( assert ( attn_bias.shape[-2] % 2 == 0 and attn_bias.shape[-1] % (2 * cp_size) == 0 ), "Sequence length does not meet divisible requirements!" - # [b, np, sq, sk] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_bias_ = attn_bias.view( *attn_bias.shape[:-2], 2, @@ -615,12 +1291,14 @@ def forward( 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size), ) - # [b, np, sq, sk] -> [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, sk] -> [b, h, sq, 2*cp, sk//(2*cp)] attn_bias = attn_bias.view( *attn_bias.shape[:-1], 2 * cp_size, attn_bias.shape[-1] // (2 * cp_size) ) - assert q.shape[-1] % 8 == 0, "hidden size per attention head should be multiple of 8" + # stats tensor shape: + # BHS1 before cuDNN 9.6 or flash-attention v2.6/v3 + # TH1 after cuDNN 9.6 or flash-attention v2.6/v3 softmax_lse_in_packed_format = False if qkv_format == "thd": if use_fused_attention: @@ -628,7 +1306,9 @@ def forward( else: softmax_lse_in_packed_format = fa_utils.v2_6_0_plus or use_flash_attn_3 + # set up args for FlashAttention backend flash_attn_fwd = None + fa_forward_kwargs = {} if not use_fused_attention: fa_forward_kwargs = {"softmax_scale": softmax_scale} if use_flash_attn_3: @@ -667,11 +1347,9 @@ def forward( if fa_utils.v2_6_0_plus: fa_forward_kwargs["softcap"] = 0.0 - # Flash Attn inputs + # set up inputs for forward q_inputs = 
[None, None] kv_inputs = [None, None] - attn_bias_inputs = [None, None] - # Flash Attn outputs out_per_step = [None for _ in range(cp_size)] softmax_lse_per_step = [None for _ in range(cp_size)] rng_states = [None for _ in range(cp_size)] @@ -683,19 +1361,15 @@ def forward( fwd_results_correction_done = torch.cuda.Event() p2p_comm_buffers = [None for _ in range(cp_size)] - if enable_mla: - # If MLA, the shape of k and v does not match, so we flatten them - # and split them after receiving them. - k_shape = k.shape - k_numel = k.numel() - v_shape = v.shape - p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) - elif qkv_format in ["bshd", "sbhd"]: - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(-3), v.unsqueeze(-3)), dim=-3) - else: # qkv_format == "thd" - p2p_comm_buffers[0] = torch.cat((k.unsqueeze(0), v.unsqueeze(0)), dim=0) + k_shape = k.shape + k_numel = k.numel() + v_shape = v.shape + p2p_comm_buffers[0] = torch.cat((k.view(-1), v.view(-1)), dim=-1) send_recv_reqs = [[], []] + # P2P communication and compute: each rank has cp_size steps + # f16 attention: q, k, v: torch.Tensor, dtype=fwd_nominal_dtype + # fp8 attention: q, k, v: torch.Tensor, dtype=torch.uint8 out = None for i in range(cp_size + 1): if i < cp_size: @@ -716,634 +1390,210 @@ def forward( batch_p2p_comm, ) - if not fp8 or is_input_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - kv_inputs[i % 2] = p2p_comm_buffers[i] + kv_inputs[i % 2] = p2p_comm_buffers[i] + k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) + v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + q_part = q + + prepare_inputs = [ + q_part, + k_part, + v_part, + qkv_format, + pad_between_seqs, + cu_seqlens_q, + cu_seqlens_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + cu_seqlens_q_half, + cu_seqlens_kv_half, + rank, + i, + cp_size, + ] + if use_fused_attention: + fused_attn_inputs = [ + attn_bias, + attn_bias_, + is_training, + max_seqlen_q, + max_seqlen_kv, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + 
fused_attn_backend, + softmax_scale, + dropout_p, + qkv_layout, + attn_mask_type, + attn_bias_type, + fp8, + q_fp8, + k_fp8, + v_fp8, + fwd_nominal_dtype, + S_quantizer_per_step[i], + O_quantizer_per_step[i], + rank, + i, + cp_size, + return_max_logit, + ] else: - # KV exchange is in BF16/FP16, cast received KV in each step - kv_inputs[i % 2] = QKV_quantizer(p2p_comm_buffers[i])._data - if enable_mla: - # If MLA, k and v are flattened, so split them after receiving. - k_part = kv_inputs[i % 2][:k_numel].view(*k_shape) - v_part = kv_inputs[i % 2][k_numel:].view(*v_shape) + flash_attn_inputs = [ + use_flash_attn_3, + qkv_format, + fa_forward_kwargs, + flash_attn_fwd, + max_seqlen_q, + max_seqlen_kv, + ] + + # cp_size = 4: + # + # step + # section | 0 1 2 3 + # -------------------- + # G 0 | d, u, u, u, + # P 1 | l, d, u, u, + # U 2 | l, l, d, u, + # 3 | l, l, l, d, + # + # Each GPU holds a slice of Q and KV. To compute the attention of each Q slice, each GPU + # runs cp_size steps to get the partial results of its own Q and all KV slices. KV is communicated + # in a point-to-point, ring fashion. For attn_mask_type = causal, there are three attention + # patterns in the cp_size x cp_size (i.e. GPU x step) matrix, the diagonal tiles, the lower-triangle + # tiles, and the upper-triangle tiles. For attn_mask_type != causal, the pattern is all the same. 
if causal: if i == 0: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, cu_seqlens_kv_padded, cp_size, rank, True, True - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - q_inputs[i % 2] = q + section = "diagonal" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - # If MHA, then split the KV into 
k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - fake_dtype=qkv_dtype, - fused_attention_backend=fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + max_logit_per_step[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. 
- k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=True, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] elif i <= rank: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - False, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // (cp_size * 2) - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv_half - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_inputs[i % 2] = q.view(q.shape[0], -1, *q.shape[-2:]) - if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk//2, np, hn] - k_part = k_part[:, 0, ...] - v_part = v_part[:, 0, ...] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][:, 0, ...] 
- elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_inputs[i % 2] = q.view(-1, *q.shape[-3:]) - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk//2, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2][0] - elif qkv_format == "thd": - q_inputs[i % 2] = q - if enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor( - k_part, cu_seqlens_kv_padded, 0 - ) - v_part = tex.thd_read_half_tensor( - v_part, cu_seqlens_kv_padded, 0 - ) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_inputs[i % 2] = tex.thd_read_half_tensor( - kv_inputs[i % 2], cu_seqlens_kv_padded, 0 - ) + section = "lower-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_inputs[i % 2] = kv_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = attn_bias[..., idx, :].contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], 
aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv // 2, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None - if cu_seqlens_kv_padded is None - else cu_seqlens_kv_padded // 2 - ), - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + max_logit_per_step[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. 
- k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv // 2, - ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, False, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // (cp_size * 2) - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q_half - cu_seqlens_kv_per_step[i] = cu_seqlens_kv - if qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_inputs[i % 2] = q[:, 1, ...] 
- if enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - k.shape[0], -1, 2, *k.shape[-2:] - ) - elif qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_inputs[i % 2] = q[1] - if enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[2:]) - v_part = v_part.view(-1, *v_part.shape[2:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_inputs[i % 2] = kv_inputs[i % 2].view( - -1, k.shape[2], 2, *k.shape[-2:] - ) - elif qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_inputs[i % 2] = tex.thd_read_half_tensor( - q, cu_seqlens_q_padded, 1 - ) + section = "upper-triangle" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - q_inputs[i % 2] = q_inputs[i % 2].contiguous() - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias_[..., 1, :, idx, :], - attn_bias_[..., 1, :, (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q_inputs[i % 2] - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, 
internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q // 2, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=( - None - if cu_seqlens_q_padded is None - else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + max_logit_per_step[i], + ) = cp_p2p_fwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. - # Otherwise (MHA), k_part and v_part have already been split. 
- k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q // 2, - max_seqlen_kv=max_seqlen_kv, - ) - if use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_forward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_forward_kwargs["window_size_left"] = -1 - fa_forward_kwargs["window_size_right"] = -1 - fa_outputs = flash_attn_fwd( - q_inputs[i % 2], - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] else: - if pad_between_seqs: - cu_seqlens_q_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_q, cu_seqlens_q_padded, cp_size, rank, True, True - ) - cu_seqlens_kv_per_step[i] = get_cu_seqlens_on_cp_rank( - cu_seqlens_kv, - cu_seqlens_kv_padded, - cp_size, - (rank - i) % cp_size, - True, - True, - ) - elif qkv_format == "thd": - cu_seqlens_q_per_step[i] = cu_seqlens_q // cp_size - cu_seqlens_kv_per_step[i] = cu_seqlens_kv // cp_size - else: - cu_seqlens_q_per_step[i] = cu_seqlens_q - cu_seqlens_kv_per_step[i] = cu_seqlens_kv + # all tiles + section = "all" + prepare_outputs = cp_p2p_fwd_prepare_qkv(*prepare_inputs, section) + ( + q_part, + k_part, + v_part, + cu_seqlens_q_per_step[i], + 
cu_seqlens_kv_per_step[i], + ) = prepare_outputs + q_inputs[i % 2] = q_part if use_fused_attention: - if attn_bias is not None: - idx = (rank - i) % cp_size - attn_bias_inputs[i % 2] = torch.cat( - ( - attn_bias[..., idx, :], - attn_bias[..., (2 * cp_size - idx - 1), :], - ), - dim=-1, - ).contiguous() - - q_part = q - if not enable_mla: - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fp8_meta_kwargs = {} - if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=qkv_dtype, internal=True - ) - fp8_meta_kwargs["s_quantizer"] = S_quantizer_per_step[i] - fp8_meta_kwargs["o_quantizer"] = O_CP_quantizer_per_step[i] - out_per_step[i], aux_ctx_tensors = fused_attn_fwd( - is_training, - max_seqlen_q, - max_seqlen_kv, - cu_seqlens_q_per_step[i], - cu_seqlens_kv_per_step[i], - q_part, - k_part, - v_part, - qkv_dtype, - fused_attn_backend, - attn_scale=softmax_scale, - dropout=dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=attn_mask_type, - attn_bias_type=attn_bias_type, - attn_bias=attn_bias_inputs[i % 2], - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - **fp8_meta_kwargs, - ) - if fp8: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors - attn_biases[i] = rest[0] if len(rest) > 0 else None + ( + out_per_step[i], + softmax_lse_per_step[i], + rng_states[i], + attn_biases[i], + max_logit_per_step[i], + ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: - if not enable_mla: - # If MHA, then split the KV into k_part and v_part. 
- # Otherwise (MHA), k_part and v_part have already been split. - k_part = ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ) - v_part = ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ) - fa_forward_args_thd = get_fa_args( - True, - use_flash_attn_3, - qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[i], - cu_seqlens_kv=cu_seqlens_kv_per_step[i], - max_seqlen_q=max_seqlen_q, - max_seqlen_kv=max_seqlen_kv, - ) - fa_outputs = flash_attn_fwd( - q, - k_part, - v_part, - *fa_forward_args_thd, - causal=False, - **fa_forward_kwargs, + out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( + cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) ) - if not fa_utils.v2_7_0_plus: - out_per_step[i] = fa_outputs[4] - softmax_lse_per_step[i] = fa_outputs[5] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[7] - else: - out_per_step[i] = fa_outputs[0] - softmax_lse_per_step[i] = fa_outputs[1] - if not use_flash_attn_3: - rng_states[i] = fa_outputs[3] + # softmax_lse correction if i > 0: - # wait until fwd restuls correction of last step is done + # wait until fwd results correction of last step is done if i > 1: flash_attn_streams[(i - 1) % 2].wait_event(fwd_results_correction_done) with torch.cuda.stream(flash_attn_streams[(i - 1) % 2]): if use_fused_attention: - # [b, np, sq, 1] -> [b, np, sq] or - # [t, np, 1] -> [t, np] + # [b, h, sq, 1] -> [b, h, sq] or + # [t, h, 1] -> [t, np] softmax_lse_per_step[i - 1].squeeze_(-1) if softmax_lse_in_packed_format: softmax_lse_per_step[i - 1] = ( softmax_lse_per_step[i - 1].transpose(0, 1).contiguous() ) if fp8: - out_per_step[i - 1] = out_per_step[i - 1].dequantize(dtype=torch.float32) + # dequantize out_per_step to torch.float32 + if fp8_recipe.delayed(): + out_per_step[i - 1] = out_per_step[i - 1].dequantize( + dtype=torch.float32 + ) + if fp8_recipe.float8_current_scaling(): + out_per_step[i - 1] = 
out_per_step[i - 1].to(dtype=torch.float32) + if i == 1: softmax_lse = torch.clone(softmax_lse_per_step[0]) if qkv_format == "thd": @@ -1373,16 +1623,26 @@ def forward( softmax_lse.view(*softmax_lse.shape[:-1], 2, -1), softmax_lse_per_step[i - 1], ) + if return_max_logit: + if i == 1: + max_logit = torch.clone(max_logit_per_step[0]) + else: + max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) if i < cp_size: flash_attn_streams[(i - 1) % 2].record_event(fwd_results_correction_done) torch.cuda.current_stream().wait_stream(flash_attn_streams[1]) + if return_max_logit: + torch.distributed.all_reduce( + max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) second_half_lse_seqlen = None if causal and rank < (cp_size - 1): second_half_lse_seqlen = softmax_lse_per_step[-1].shape[-1] + # fwd output correction: out in torch.float32 for i in range(cp_size): if i <= rank or not causal: if qkv_format in ["bshd", "sbhd"]: @@ -1435,7 +1695,6 @@ def forward( softmax_lse_in_packed_format, ) - kv = p2p_comm_buffers[-1] if qkv_format == "bshd": out = out.view(out.shape[0], -1, *out.shape[-2:]) ctx.batch_size = out.shape[0] @@ -1450,39 +1709,88 @@ def forward( ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] + # [b*s, h, d] -> [b, s, h, d] out = out.view(ctx.batch_size, -1, *out.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] + # [s*b, h, d] -> [s, b, h, d] out = out.view(-1, ctx.batch_size, *out.shape[-2:]) + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + max_logit, 0, cp_size_a2a, cp_group_a2a, cp_stream, False + ) elif not use_fused_attention: out = out.view(-1, *out.shape[-2:]) + # update FP8 quantizers: amax across cp_size steps if fp8 and use_fused_attention: amax_cp_fwd = amax_per_step.amax(dim=1) S_quantizer.amax.copy_(amax_cp_fwd[0]) - O_CP_quantizer.amax.copy_(amax_cp_fwd[1]) - - out_fp8 = None - out_f16 = out.to(qkv_dtype) + 
O_quantizer.amax.copy_(amax_cp_fwd[1]) - if fp8 and (is_output_fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1"))): - out_fp8 = O_quantizer(out_f16) # final result + if fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.forward >> after: ", + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, + ) + # prepare for return and ctx saves + out_fp8 = None + out_f16 = out.to(fwd_nominal_dtype) + if fp8 and ( + is_output_fp8 + or (is_bwd_fp8 and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16)) + ): + out_fp8 = O_quantizer(out_f16) out_ret = out_fp8 if (fp8 and is_output_fp8) else out_f16 - if fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, kv_save, out_save = q, kv, out_fp8._data + ctx.layer_number = layer_number + ctx.fp8_recipe = fp8_recipe + ctx.fp8 = fp8 and is_bwd_fp8 + + kv_fp8 = None + kv = p2p_comm_buffers[-1] + if fp8: + q_fp8, kv_fp8 = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8], [q, kv]) + ] + # q, kv, out + fp8_tensors = (None, None, None) + f16_tensors = (None, None, None) + if ctx.fp8: + # fwd: fp8, bwd: fp8, save all fp8 + fp8_tensors = (q_fp8, kv_fp8, out_fp8) + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + f16_tensors = (None, None, out_f16) elif fp8 and is_input_fp8: - q_save, kv_save, out_save = q, kv, out_f16 + # fwd: fp8, bwd: f16, save all f16 + # dequantize fp8 inputs + q_f16 = q_fp8.dequantize() + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) + elif fp8: + # fwd: fp8, bwd: f16, save all f16 + # inputs are already in f16 + q_f16 = q_f16.view(q.shape) + kv_f16 = kv_fp8.dequantize() + f16_tensors = (q_f16, kv_f16, out_f16) else: + # fwd: f16, bwd: f16, save all f16 + # inputs and kernels are both f16 q_f16 = q_f16.view(q.shape) - q_save, kv_save, out_save = q_f16, kv, out_f16 + kv_f16 = kv + f16_tensors = (q_f16, kv_f16, out_f16) tensors_to_save, 
tensor_objects = prepare_for_saving( - q_save, - kv_save, - out_save, + *fp8_tensors, + *f16_tensors, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, @@ -1512,21 +1820,18 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 ctx.use_flash_attn_3 = use_flash_attn_3 ctx.enable_mla = enable_mla - if enable_mla: - ctx.k_numel = k_numel - ctx.k_shape = k_shape - ctx.v_shape = v_shape + ctx.k_numel = k_numel + ctx.k_shape = k_shape + ctx.v_shape = v_shape - ctx.qkv_dtype = qkv_dtype + ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.dQKV_quantizer = dQKV_quantizer - ctx.dQKV_CP_quantizer = dQKV_CP_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.QKV_quantizer = QKV_quantizer @@ -1539,17 +1844,33 @@ def forward( ctx.O_quantizer.scale = O_quantizer.scale.clone() ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") + nvtx_range_pop(f"{nvtx_label}") + + if return_max_logit: + return out_ret, max_logit return out_ret @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *_args): # pylint: disable=missing-function-docstring - nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + # add NVTX range + nvtx_label = "transformer_engine.AttnFuncWithCPAndKVP2P.backward" + nvtx_range_push(f"{nvtx_label}") + + # dout is expected to be in FP8 if is_output_fp8=True, + # but in the case it's not, convert it to FP8 before any operation + if ctx.fp8 and ctx.is_output_fp8 and not isinstance(dout, QuantizedTensorStorage): + dout = ctx.dO_quantizer(dout) + if ctx.use_fused_attention: + dout._data = dout._data.contiguous() + elif ctx.use_fused_attention: + 
dout = dout.contiguous() + + # set up CP groups for cp_comm_type = {'p2p', 'a2a+p2p'} cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a - cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) send_dst = ctx.cp_global_ranks[(rank - 1) % cp_size * cp_size_a2a + rank_a2a] @@ -1559,33 +1880,38 @@ def backward(ctx, dout): device_compute_capability < (10, 0) and cp_size == 2 ) - q, kv, out, softmax_lse, cu_seqlens_q_padded, cu_seqlens_kv_padded, *other_tensors = ( - restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - ) + # get saved tensors + ( + q_fp8, + kv_fp8, + out_fp8, + q, + kv, + out, + softmax_lse, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + *other_tensors, + ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) cu_seqlens_q_per_step = other_tensors[:cp_size] cu_seqlens_kv_per_step = other_tensors[cp_size : cp_size * 2] rng_states = other_tensors[cp_size * 2 : cp_size * 3] attn_biases = other_tensors[cp_size * 3 : cp_size * 4] + # set up attention args causal = "causal" in ctx.attn_mask_type - padding = "padding" in ctx.attn_mask_type - seq_dim = None + qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format if ctx.qkv_format in ["bshd", "sbhd"]: seq_dim = ctx.qkv_format.index("s") - if ctx.enable_mla: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format[:-2] + "2" + ctx.qkv_format[-2:] - else: - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format + # set up attention bias if attn_biases[0] is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] attn_dbias = torch.zeros( *ctx.attn_bias_shape, dtype=attn_biases[0].dtype, device=attn_biases[0].device ) - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, 2, sq//2, 2*cp, sk//(2*cp)] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, 2, sq//2, 2*cp, sk//(2*cp)] attn_dbias_ = attn_dbias.view( *attn_dbias.shape[:-3], 2, 
attn_dbias.shape[-3] // 2, *attn_dbias.shape[-2:] ) @@ -1593,6 +1919,7 @@ def backward(ctx, dout): attn_dbias = None attn_dbias_ = None + # set up softmax_lse softmax_lse_ = None if causal and ctx.second_half_lse_seqlen is not None: if ctx.qkv_format == "thd": @@ -1603,86 +1930,124 @@ def backward(ctx, dout): ctx.second_half_lse_seqlen, ) else: - # [b, np, sq] -> [b, np, 2, sq//2] + # [b, h, sq] -> [b, h, 2, sq//2] softmax_lse_ = softmax_lse.view(*softmax_lse.shape[:-1], 2, -1) softmax_lse_ = softmax_lse_[..., 1, :].contiguous() if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse_ = softmax_lse_.transpose(0, 1).contiguous() - # [b, np, sq//2] -> [b, np, sq//2, 1] or - # [t//2, np] -> [t//2, np, 1] + # [b, h, sq//2] -> [b, h, sq//2, 1] or + # [t//2, np] -> [t//2, h, 1] softmax_lse_.unsqueeze_(-1) if ctx.use_fused_attention: if ctx.softmax_lse_in_packed_format: softmax_lse = softmax_lse.transpose(0, 1).contiguous() - # [b, np, sq] -> [b, np, sq, 1] or - # [t, np] -> [t, np, 1] + # [b, h, sq] -> [b, h, sq, 1] or + # [t, np] -> [t, h, 1] softmax_lse.unsqueeze_(-1) - dout = dout.contiguous() - dq = None - dout_dtype = dout.dtype + # assume fwd and bwd always use the same high precision, i.e. torch.float16 or torch.bfloat16 + # used when some tensors are base tensors and loose the "dtype" attribute + bwd_nominal_dtype = ctx.fwd_nominal_dtype + + # convert out, dout to the right type fused_attn_backend = None - fused_attn_dqkv_dtype = None amax_per_step = None dP_quantizer_per_step = [None for _ in range(cp_size)] - dQKV_CP_quantizer_per_step = [None for _ in range(cp_size)] + dQKV_quantizer_per_step = [None for _ in range(cp_size)] + buffer_dtype = torch.uint8 + dq_buffer = None + dout_fp8 = None + bwd_output_te_dtype = None + dkv_buffer = None if ctx.fp8: - if ctx.use_fused_attention: - fused_attn_backend = FusedAttnBackend["FP8"] + assert ctx.use_fused_attention, "FP8 is only supported with Fused Attention!" 
+ fused_attn_backend = FusedAttnBackend["FP8"] + q, kv, out = ( + q_fp8._data, + kv_fp8._data, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8._data + ), + ) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: - dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] - dq_fp8 = torch.empty((cp_size, *q.shape), dtype=dout._data.dtype, device=q.device) - dkv_fp8 = torch.empty( - (cp_size, *kv.shape), dtype=dout._data.dtype, device=kv.device - ) - dkv_fp8_ = torch.empty_like(dkv_fp8) - p2p_comm_buffers = [[kv, dkv_fp8], [torch.empty_like(kv), dkv_fp8_]] - dout = dout._data - fp8_meta_kwargs = {} - fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer - amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) - for i in range(cp_size): - dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() - dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) - dQKV_CP_quantizer_per_step[i] = ctx.dQKV_CP_quantizer.copy() - dQKV_CP_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) + # dout_fp8: Float8Tensor, dtype=bwd_nominal_dtype + # dout: torch.Tensor, dtype=torch.uint8 + if ctx.is_output_fp8: + dout_fp8 = dout else: - assert False, "FP8 is only supported with Fused Attention!" 
+ dout_fp8 = ctx.dO_quantizer(dout) + dout = dout_fp8._data + + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> before: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) + + # dout_fp8._fp8_dtype + bwd_output_te_dtype = ctx.dO_quantizer.dtype + + # create buffers for reduction in float32 + if ctx.fp8_recipe.delayed(): + dq_buffer = torch.empty( + (cp_size, *q.shape), + dtype=buffer_dtype, + device=q.device, + ) + if ctx.fp8_recipe.float8_current_scaling(): + dq_buffer = torch.empty( + q.shape, + dtype=torch.float32, + device=q.device, + ) + kv_recv_buffer = torch.empty_like(kv) + dkv_send_buffer = torch.empty( + (cp_size, *kv.shape), + dtype=buffer_dtype, + device=kv.device, + ) + dkv_recv_buffer = torch.empty_like(dkv_send_buffer) + p2p_comm_buffers = [[kv, dkv_send_buffer], [kv_recv_buffer, dkv_recv_buffer]] + if ctx.fp8_recipe.float8_current_scaling(): + dkv_buffer = torch.zeros( + kv.shape, + dtype=torch.float32, + device=kv.device, + ) + + # amax_per_step[0]: amax_dp x cp_size + # amax_per_step[1]: amax_dqkv x cp_size + amax_per_step = torch.zeros((2, cp_size), dtype=torch.float32, device=q.device) + # per_step tensors are not reduced even if Float8CurrentScaling.with_amax_reduction=True; + # only used to hold temporary scale/amax values (output only, no quantization op) + for i in range(cp_size): + dP_quantizer_per_step[i] = ctx.dP_quantizer.copy() + dP_quantizer_per_step[i].amax = amax_per_step[0][i].reshape((1,)) + dQKV_quantizer_per_step[i] = ctx.dQKV_quantizer.copy() + dQKV_quantizer_per_step[i].amax = amax_per_step[1][i].reshape((1,)) else: - if ctx.fp8_meta is not None: - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - kv = ctx.QKV_quantizer.create_tensor_from_data( - kv, fake_dtype=ctx.qkv_dtype, internal=True - ) - q = q.dequantize(dtype=ctx.qkv_dtype) - 
kv = kv.dequantize(dtype=ctx.qkv_dtype) - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - if cp_size_a2a == 1: - dout = dout.dequantize(dtype=dout_dtype) - else: - ctx.dO_quantizer = dout._quantizer - dout = dout._data - dq = torch.empty_like(q) + if isinstance(dout, QuantizedTensorStorage): + dout = dout.dequantize(dtype=bwd_nominal_dtype) + dq_buffer = torch.empty_like(q) p2p_comm_buffers = [ torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), torch.empty((2, *kv.shape), dtype=kv.dtype, device=kv.device), ] p2p_comm_buffers[0][0].copy_(kv) if ctx.use_fused_attention: - fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + bwd_output_te_dtype = TE_DType[bwd_nominal_dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] + # communicate for the 'a2a' part of 'a2a+p2p' if cp_size_a2a > 1: if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) @@ -1699,11 +2064,6 @@ def backward(ctx, dout): ctx.cp_stream, True, ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - dout = dout.dequantize(dtype=dout_dtype) if ctx.enable_mla: out = out.view(*ctx.v_shape) @@ -1712,7 +2072,6 @@ def backward(ctx, dout): # MHA or GQA out = out.view(*q.shape) dout = dout.view(*q.shape) - send_recv_reqs = [] flash_attn_bwd = None if not ctx.use_fused_attention: @@ -1747,6 +2106,7 @@ def backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + send_recv_reqs = [] for i in range(cp_size): # wait until KV is received for req in send_recv_reqs: @@ -1767,8 +2127,8 @@ def backward(ctx, dout): ) else: dkv_a2a_req = torch.distributed.all_to_all_single( - dkv_fp8, - dkv_fp8_, + dkv_send_buffer, + dkv_recv_buffer, group=ctx.cp_group, async_op=True, ) @@ -1785,593 +2145,146 @@ def backward(ctx, dout): 
) kv = p2p_comm_buffers[i % 2][0] - q_, kv_, out_, dout_ = None, None, None, None dq_, dk_, dv_ = None, None, None - if ctx.enable_mla: - k_part = kv[: ctx.k_numel].view(*ctx.k_shape) - v_part = kv[ctx.k_numel :].view(*ctx.v_shape) - # In reversed order of fwd + k_part = kv[: ctx.k_numel].view(*ctx.k_shape) + v_part = kv[ctx.k_numel :].view(*ctx.v_shape) + q_part, out_part, dout_part = q, out, dout + + prepare_inputs = [ + q_part, + k_part, + v_part, + out_part, + dout_part, + ctx.qkv_format, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + ] + if ctx.use_fused_attention: + fused_attn_inputs = [ + ctx.fp8, + ctx.fp8_recipe, + q_fp8, + kv_fp8, + ( + out + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + else out_fp8 + ), + dout_fp8, + softmax_lse, + softmax_lse_, + rng_states, + attn_dbias, + attn_biases, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + i, + cp_size, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, + fused_attn_backend, + ctx.softmax_scale, + ctx.dropout_p, + qkv_layout, + ctx.attn_mask_type, + ctx.attn_bias_type, + ctx.deterministic, + ctx.fwd_nominal_dtype, + bwd_nominal_dtype, + bwd_output_te_dtype, + ctx.S_quantizer, + dP_quantizer_per_step[i], + dQKV_quantizer_per_step[i], + ] + else: + flash_attn_inputs = [ + ctx.use_flash_attn_3, + ctx.qkv_format, + ctx.max_seqlen_q, + ctx.max_seqlen_kv, + cu_seqlens_q_per_step, + cu_seqlens_kv_per_step, + i, + cp_size, + fa_backward_kwargs, + flash_attn_bwd, + rng_states, + softmax_lse, + softmax_lse_, + ] + + # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, + # there are still three sections in these tiles based on their attention pattern + # for attn_mask_type = causal, and one for attn_mask_type != causal. 
if causal: if i == (cp_size - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - q_, kv_, out_, dout_ = q, kv, out, dout + section = "diagonal" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( 
- out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.use_flash_attn_3 or 
( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, 0) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = 0 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=True, - **fa_backward_kwargs, - ) - elif i >= (cp_size - rank - 1): - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - q_, out_, dout_ = [ - x.view(x.shape[0], -1, *x.shape[-2:]) for x in [q, out, dout] - ] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part[:, 0] - v_part = v_part[:, 0] - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk//2, 2, np, hn] - kv_ = kv[:, 0] - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - q_, out_, dout_ = [x.view(-1, *x.shape[-3:]) for x in [q, out, dout]] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - k_part = k_part[0] - v_part = v_part[0] - else: - # [2, sk//2, b, 2, np, hn] -> [sk//2, b, 2, np, hn] - kv_ = kv[0] - elif ctx.qkv_format == "thd": - q_, out_, dout_ = q, out, dout - if ctx.enable_mla: - # [t, np, hn] -> [t/2, np, hn] - k_part = tex.thd_read_half_tensor(k_part, cu_seqlens_kv_padded, 0) - v_part = tex.thd_read_half_tensor(v_part, cu_seqlens_kv_padded, 0) - else: - # [2, t, np, hn] -> [2, t/2, np, hn] - kv_ = tex.thd_read_half_tensor(kv, cu_seqlens_kv_padded, 0) - if ctx.use_fused_attention: - if ctx.enable_mla: - k_part = k_part.contiguous() - v_part = v_part.contiguous() - else: - kv_ = kv_.contiguous() - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q_ - if not ctx.enable_mla: - 
k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv // 2, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=( - None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded // 2 - ), - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + elif i >= (cp_size - rank - 1): + section = "lower-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) + if ctx.use_fused_attention: + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - k_part = k_part.contiguous() - 
v_part = v_part.contiguous() - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv // 2, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) else: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - q_, out_, dout_ = q[:, 1], out[:, 1], dout[:, 1] - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - k_part = k_part.view(k_part.shape[0], -1, *k_part.shape[-2:]) - v_part = v_part.view(v_part.shape[0], -1, *v_part.shape[-2:]) - else: - # [b, 2, sk//2, 2, np, hn] -> [b, sk, 2, np, hn] - kv_ = kv.view(kv.shape[0], -1, *kv.shape[-3:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] - q_, out_, dout_ = q[1], out[1], dout[1] - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, 
np, hn] - k_part = k_part.view(-1, *k_part.shape[-3:]) - v_part = v_part.view(-1, *v_part.shape[-3:]) - else: - # [2, sk//2, b, 2, np, hn] -> [sk, b, 2, np, hn] - kv_ = kv.view(-1, *kv.shape[-4:]) - elif ctx.qkv_format == "thd": - # [t, np, hn] -> [t/2, np, hn] - q_, out_, dout_ = [ - tex.thd_read_half_tensor(x, cu_seqlens_q_padded, 1) - for x in [q, out, dout] - ] - kv_ = kv + section = "upper-triangle" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - q_, out_, dout_ = [x.contiguous() for x in [q_, out_, dout_]] - if ctx.fp8: - aux_ctx_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - i - 1], - ] - else: - aux_ctx_tensors = [softmax_lse_, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - - q_part = q_ - if not ctx.enable_mla: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - out_part = out_ - dout_part = dout_ - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q // 2, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - 
dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=( - None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 - ), - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type="padding" if padding else "no_mask", - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data else: - dq_ = torch.empty_like(q_) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = ( - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] - ) - v_part = ( - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] - ) - dkv_ = torch.empty_like(kv_) - dk_ = ( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ) - dv_ = ( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ) - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q // 2, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or ( - fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus - ): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout_, - q_, - k_part, - v_part, - out_, - softmax_lse_, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, 
*prepare_outputs, section ) else: + section = "all" + prepare_outputs = cp_p2p_bwd_prepare_qkv(*prepare_inputs, section) if ctx.use_fused_attention: - if ctx.fp8: - aux_ctx_tensors = [softmax_lse, softmax_lse, rng_states[cp_size - i - 1]] - else: - aux_ctx_tensors = [softmax_lse, rng_states[cp_size - i - 1]] - if attn_dbias is not None: - aux_ctx_tensors += [attn_biases[cp_size - i - 1]] - q_part = q - if not ctx.enable_mla: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - out_part = out - dout_part = dout - - if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - fp8_meta_kwargs["dp_quantizer"] = dP_quantizer_per_step[i] - fp8_meta_kwargs["dqkv_quantizer"] = dQKV_CP_quantizer_per_step[i] - dq_, dk_, dv_, dbias_ = fused_attn_bwd( - ctx.max_seqlen_q, - ctx.max_seqlen_kv, - cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv_per_step[cp_size - i - 1], - q_part, - k_part, - v_part, - out_part, - dout_part, - dout_dtype, - fused_attn_dqkv_dtype, - aux_ctx_tensors, - fused_attn_backend, - cu_seqlens_q_padded=cu_seqlens_q_padded, - cu_seqlens_kv_padded=cu_seqlens_kv_padded, - attn_scale=ctx.softmax_scale, - dropout=ctx.dropout_p, - qkv_layout=qkv_layout, - attn_mask_type=ctx.attn_mask_type, - attn_bias_type=ctx.attn_bias_type, - deterministic=ctx.deterministic, - **fp8_meta_kwargs, + dq_, dk_, dv_, dbias_ = cp_p2p_bwd_fused_attn( + *fused_attn_inputs, *prepare_outputs, 
section ) - - if ctx.fp8: - dq_ = dq_._data - dk_ = dk_._data - dv_ = dv_._data - else: - dq_ = torch.empty_like(q) - if ctx.enable_mla: - dk_ = torch.empty_like(k_part) - dv_ = torch.empty_like(v_part) - else: - k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] - v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] - dkv_ = torch.empty_like(kv) - dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] - dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] - fa_backward_args_thd = get_fa_args( - False, - ctx.use_flash_attn_3, - ctx.qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - i - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - i - 1], - max_seqlen_q=ctx.max_seqlen_q, - max_seqlen_kv=ctx.max_seqlen_kv, - dq=dq_, - dk=dk_, - dv=dv_, - ) - if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): - fa_backward_kwargs["window_size"] = (-1, -1) - elif fa_utils.v2_7_0_plus: - fa_backward_kwargs["window_size_left"] = -1 - fa_backward_kwargs["window_size_right"] = -1 - if not ctx.use_flash_attn_3: - fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - flash_attn_bwd( - dout, - q, - k_part, - v_part, - out, - softmax_lse, - *fa_backward_args_thd, - causal=False, - **fa_backward_kwargs, + dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( + *flash_attn_inputs, *prepare_outputs, section ) - if ctx.fp8: - dq = dq_fp8[(rank + i + 1) % cp_size] + # dq, dk, dv are reduced across steps in higher precision + # DelayedScaling: collect all results in uint8 to one tensor, dequantize to float32, then reduce + # CurrentScaling: dequantize partial results from each step to float32, then reduce + if ctx.fp8 and ctx.use_fused_attention: + if ctx.fp8_recipe.delayed(): + dq_, dk_, dv_ = [x._data for x in [dq_, dk_, dv_]] + if ctx.fp8_recipe.float8_current_scaling(): + dq_, dk_, dv_ = [x.to(torch.float32) for x in [dq_, dk_, dv_]] + + # copy dq_ into the right buffer 
position + # buffer is cp_size x dq_size for DelayedScaling and the same size as dq for CurrentScaling + if ctx.fp8 and ctx.fp8_recipe.delayed(): + dq = dq_buffer[(rank + i + 1) % cp_size] + else: + dq = dq_buffer if causal and ctx.qkv_format in ["bshd", "sbhd"] and i >= (cp_size - rank - 1): - # [b, sq, np, hn] -> [b, 2, sq//2, np, hn] or - # [sq, b, np, hn] -> [2, sq//2, b, np, hn] + # [b, sq, h, d] -> [b, 2, sq//2, h, d] or + # [sq, b, h, d] -> [2, sq//2, b, h, d] dq_ = dq_.view(*dq.shape) - - if ctx.fp8: + if ctx.fp8 and ctx.fp8_recipe.delayed(): if i >= (cp_size - rank - 1) or not causal: dq.copy_(dq_) else: @@ -2381,6 +2294,8 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq[0].fill_(0) dq[1].copy_(dq_) + else: + dq.copy_(dq_) elif causal: if i > (cp_size - rank - 1): dq.add_(dq_) @@ -2416,18 +2331,19 @@ def backward(ctx, dout): else: dq.add_(dq_) + # dbias correction if attn_dbias is not None: idx = (rank + i + 1) % cp_size if i == (cp_size - 1) or not causal: - # [b, np, sq, sk//cp] -> [b, np, sq, 2, sk//(2*cp)] + # [b, h, sq, sk//cp] -> [b, h, sq, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias[..., idx, :].copy_(dbias_[..., 0, :]) attn_dbias[..., (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) elif i >= (cp_size - rank - 1): - # [b, np, sq, sk//(2*cp)] + # [b, h, sq, sk//(2*cp)] attn_dbias[..., idx, :].copy_(dbias_) else: - # [b, np, sq//2, sk//cp] -> [b, np, sq//2, 2, sk//(2*cp)] + # [b, h, sq//2, sk//cp] -> [b, h, sq//2, 2, sk//(2*cp)] dbias_ = dbias_.view(*dbias_.shape[:-1], 2, dbias_.shape[-1] // 2) attn_dbias_[..., 1, :, idx, :].copy_(dbias_[..., 0, :]) attn_dbias_[..., 1, :, (2 * cp_size - idx - 1), :].copy_(dbias_[..., 1, :]) @@ -2436,254 +2352,159 @@ def backward(ctx, dout): for req in send_recv_reqs: req.wait() - if ctx.fp8: - if i < cp_size - 1: - dkv = dkv_fp8_[(rank + i + 1) % cp_size] - else: - dkv = dkv_fp8[(rank + i + 1) % cp_size] + # dkv correction + if ctx.fp8 and 
ctx.fp8_recipe.delayed(): + dkv = dkv_recv_buffer[(rank + i + 1) % cp_size] + elif ctx.fp8 and ctx.fp8_recipe.float8_current_scaling(): + dkv = dkv_buffer else: dkv = p2p_comm_buffers[(i + 1) % 2][1] - if ctx.use_fused_attention: - if ctx.enable_mla: - dkv_ = None - elif ctx.qkv_format in ["bshd", "sbhd"]: - dkv_ = combine_tensors([dk_, dv_], -2) - elif ctx.qkv_format == "thd": - dkv_ = torch.cat( - (dk_.unsqueeze(0), dv_.unsqueeze(0)), dim=0 - ) # pylint: disable=used-before-assignment - if not ctx.enable_mla and ctx.qkv_format in ["bshd", "sbhd"]: - # [b, 2, sk//2, 2, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk//2, b, 2, np, hn] -> [2, 2, sk//2, b, np, hn] - # dkv is a buffer, so we do not need to transpose it, but only need to reshape it. - dkv = dkv.view(2, *dkv.shape[0:-3], *dkv.shape[-2:]) - dkv_ = dkv_.movedim(-3, 0) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - # [2, b, sk, np, hn] -> [2, b, 2, sk//2, np, hn] or - # [2, sk, b, np, hn] -> [2, 2, sk//2, b, np, hn] - dkv_ = dkv_.view(*dkv.shape) - - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] or - # [2, sk//2, b, np, hn] - dk = dkv[: ctx.k_numel].view(*ctx.k_shape) - dv = dkv[ctx.k_numel :].view(*ctx.v_shape) - if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): - dk_ = dk_.view(*ctx.k_shape) - dv_ = dv_.view(*ctx.v_shape) - - if ctx.fp8: - # enable_mla and fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dk[:, 1, ...].fill_(0) - dv[:, 0, ...].copy_(dv_) - dv[:, 1, ...].fill_(0) - elif ctx.qkv_format == "sbhd": - dk[0].copy_(dk_) - dk[1].fill_(0) - dv[0].copy_(dv_) - dv[1].fill_(0) - else: - dk.copy_(dk_) - dv.copy_(dv_) - elif causal: - # enable_mla and not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_[:, 0, ...]) - dk[:, 1, ...].copy_(dk_[:, 1, ...]) - dv[:, 0, ...].add_(dv_[:, 0, ...]) - dv[:, 1, ...].copy_(dv_[:, 1, ...]) 
- elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_[0, ...]) - dk[1, ...].copy_(dk_[1, ...]) - dv[0, ...].add_(dv_[0, ...]) - dv[1, ...].copy_(dv_[1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "copy" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dk.add_(dk_) - dv.add_(dv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dk[:, 0, ...].copy_(dk_) - dv[:, 0, ...].copy_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].copy_(dk_) - dv[0, ...].copy_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "copy", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dk[:, 0, ...].add_(dk_) - dv[:, 0, ...].add_(dv_) - elif ctx.qkv_format == "sbhd": - dk[0, ...].add_(dk_) - dv[0, ...].add_(dv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dk, dk_, cu_seqlens_kv_padded, "add", "none" - ) - tex.thd_grad_correction( - dv, dv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dk.add_(dk_) - dv.add_(dv_) - else: # i == 0 + + # [b, 2, sk//2, h, d] or + # [2, sk//2, b, h, d] + dk = dkv[: ctx.k_numel].view(*ctx.k_shape) + dv = dkv[ctx.k_numel :].view(*ctx.v_shape) + if causal and (i < (cp_size - rank - 1) or i == (cp_size - 1)): + dk_ = dk_.view(*ctx.k_shape) + dv_ = dv_.view(*ctx.v_shape) + + if ctx.fp8 and ctx.fp8_recipe.delayed(): + # fp8 + if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + if ctx.qkv_format == "bshd": + dk[:, 0, ...].copy_(dk_) + dk[:, 1, ...].fill_(0) + dv[:, 0, ...].copy_(dv_) + dv[:, 1, ...].fill_(0) + elif ctx.qkv_format == "sbhd": + dk[0].copy_(dk_) + dk[1].fill_(0) + dv[0].copy_(dv_) + dv[1].fill_(0) + else: dk.copy_(dk_) dv.copy_(dv_) else: - # enable_mla and not fp8 and not causal - if i == 0: - dk.copy_(dk_) - dv.copy_(dv_) - else: 
# i > 0 + dk.copy_(dk_) + dv.copy_(dv_) + elif causal: + # not fp8 and causal + if i == (cp_size - 1): + if rank == 0: + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_[:, 0, ...]) + dk[:, 1, ...].copy_(dk_[:, 1, ...]) + dv[:, 0, ...].add_(dv_[:, 0, ...]) + dv[:, 1, ...].copy_(dv_[:, 1, ...]) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_[0, ...]) + dk[1, ...].copy_(dk_[1, ...]) + dv[0, ...].add_(dv_[0, ...]) + dv[1, ...].copy_(dv_[1, ...]) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "copy") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "copy") + else: dk.add_(dk_) dv.add_(dv_) - else: - if ctx.fp8: - # fp8 - if causal and i >= (cp_size - rank - 1) and i != (cp_size - 1): + elif i >= (cp_size - rank - 1): + if i == 0 and rank == (cp_size - 1): if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - dkv[:, :, 1, ...].fill_(0) + dk[:, 0, ...].copy_(dk_) + dv[:, 0, ...].copy_(dv_) elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - dkv[:, 1, ...].fill_(0) + dk[0, ...].copy_(dk_) + dv[0, ...].copy_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "copy", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "copy", "none") else: - dkv.copy_(dkv_) - elif causal: - # not fp8 and causal - if i == (cp_size - 1): - if rank == 0: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_[:, :, 0, ...]) - dkv[:, :, 1, ...].copy_(dkv_[:, :, 1, ...]) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_[:, 0, ...]) - dkv[:, 1, ...].copy_(dkv_[:, 1, ...]) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "copy" - ) - else: - dkv.add_(dkv_) - elif i >= (cp_size - rank - 1): - if i == 0 and rank == (cp_size - 1): - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].copy_(dkv_) - elif ctx.qkv_format == "thd": - 
tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "copy", "none" - ) - else: - if ctx.qkv_format == "bshd": - dkv[:, :, 0, ...].add_(dkv_) - elif ctx.qkv_format == "sbhd": - dkv[:, 0, ...].add_(dkv_) - elif ctx.qkv_format == "thd": - tex.thd_grad_correction( - dkv, dkv_, cu_seqlens_kv_padded, "add", "none" - ) - elif i > 0: - dkv.add_(dkv_) - else: # i == 0 - dkv.copy_(dkv_) - else: - # not fp8 and not causal - if i == 0: - dkv.copy_(dkv_) - else: # i > 0 - dkv.add_(dkv_) + if ctx.qkv_format == "bshd": + dk[:, 0, ...].add_(dk_) + dv[:, 0, ...].add_(dv_) + elif ctx.qkv_format == "sbhd": + dk[0, ...].add_(dk_) + dv[0, ...].add_(dv_) + elif ctx.qkv_format == "thd": + tex.thd_grad_correction(dk, dk_, cu_seqlens_kv_padded, "add", "none") + tex.thd_grad_correction(dv, dv_, cu_seqlens_kv_padded, "add", "none") + elif i > 0: + dk.add_(dk_) + dv.add_(dv_) + else: # i == 0 + dk.copy_(dk_) + dv.copy_(dv_) + else: + # not fp8 and not causal + if i == 0: + dk.copy_(dk_) + dv.copy_(dv_) + else: # i > 0 + dk.add_(dk_) + dv.add_(dv_) + # sum up all cp_size for dq, dk, dv if ctx.fp8 and ctx.use_fused_attention: amax_cp_bwd = amax_per_step.amax(dim=1) ctx.dP_quantizer.amax.copy_(amax_cp_bwd[0]) - ctx.dQKV_CP_quantizer.amax.copy_(amax_cp_bwd[1]) - dq = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dq_fp8, fake_dtype=torch.float32, internal=True - ) - - if ctx.enable_mla: - # [cp, b, 2, sk//2, np, hn] or [cp, 2, sk//2, b, np, hn] - dk_fp8 = dkv_fp8[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) - dv_fp8 = dkv_fp8[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) - dk = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dk_fp8, fake_dtype=torch.float32, internal=True - ) - dv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dv_fp8, fake_dtype=torch.float32, internal=True - ) - dq, dk, dv = [x.dequantize(dtype=torch.float32) for x in [dq, dk, dv]] - dq, dk, dv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dk, dv]] - else: - if ctx.qkv_format in ["bshd", "sbhd"]: - # [cp, b, 2, 
sk//2, 2, np, hn] -> [cp, 2, b, 2, sk//2, np, hn] or - # [cp, 2, sk//2, b, 2, np, hn] -> [cp, 2, 2, sk//2, b, np, hn] - dkv_fp8 = dkv_fp8.view(cp_size, 2, *dkv_fp8.shape[1:-3], *dkv_fp8.shape[-2:]) - dkv = ctx.dQKV_CP_quantizer.create_tensor_from_data( - dkv_fp8, fake_dtype=torch.float32, internal=True + ctx.dQKV_quantizer.amax.copy_(amax_cp_bwd[1]) + + dq = dq_buffer + if ctx.fp8_recipe.delayed(): + # [cp, b, 2, sk//2, h, d] or [cp, 2, sk//2, b, h, d] + dk = dkv_recv_buffer[:, : ctx.k_numel].view(cp_size, *ctx.k_shape) + dv = dkv_recv_buffer[:, ctx.k_numel :].view(cp_size, *ctx.v_shape) + dq, dk, dv = [ + ctx.dQKV_quantizer.create_tensor_from_data( + x, fake_dtype=bwd_nominal_dtype, internal=ctx.dQKV_quantizer.internal + ) + for x in [dq, dk, dv] + ] + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + des_nominal_dtype=torch.float32, ) - dq, dkv = [x.dequantize(dtype=torch.float32) for x in [dq, dkv]] - dq, dkv = [x.sum(dim=0).to(dout_dtype) for x in [dq, dkv]] + dq, dk, dv = [x.sum(dim=0).to(bwd_nominal_dtype) for x in [dq, dk, dv]] - if causal: - if ctx.qkv_format == "bshd": - # [b, 2, sq//2, np, hn] -> [b, sq, np, hn] - dq = dq.view(dq.shape[0], -1, *dq.shape[-2:]) - if ctx.enable_mla: - # [b, 2, sk//2, np, hn] -> [b, sk, np, hn] - dk = dk.view(dk.shape[0], -1, *dk.shape[-2:]) - dv = dv.view(dv.shape[0], -1, *dv.shape[-2:]) - else: - # [2, b, 2, sk//2, np, hn] -> [2, b, sk, np, hn] - dkv = dkv.view(*dkv.shape[0:2], -1, *dkv.shape[-2:]) - elif ctx.qkv_format == "sbhd": - # [2, sq//2, b, np, hn] -> [sq, b, np, hn] - dq = dq.view(-1, *dq.shape[-3:]) - if ctx.enable_mla: - # [2, sk//2, b, np, hn] -> [sk, b, np, hn] - dk = dk.view(-1, *dk.shape[-3:]) - dv = dv.view(-1, *dv.shape[-3:]) - else: - # [2, 2, sk//2, b, np, hn] -> [2, sk, b, np, hn] - dkv = dkv.view(dkv.shape[0], -1, *dkv.shape[-3:]) + if ctx.fp8_recipe.float8_current_scaling(): + dk = dkv[: ctx.k_numel].view(ctx.k_shape) + dv = dkv[ctx.k_numel 
:].view(ctx.v_shape) + + if causal and ctx.qkv_format in ["bshd", "sbhd"]: + # [b, 2, s//2, h, d] -> [b, s, h, d] + # [2, s//2, b, h, d] -> [s, b, h, d] + dim = ctx.qkv_format.index("s") + dq, dk, dv = [x.view(*x.shape[:dim], -1, *x.shape[dim + 2 :]) for x in [dq, dk, dv]] if ctx.qkv_format == "thd" and not ctx.use_fused_attention: dq[cu_seqlens_q_padded[-1] :].fill_(0) - if ctx.enable_mla: - dk[cu_seqlens_kv_padded[-1] :].fill_(0) - dv[cu_seqlens_kv_padded[-1] :].fill_(0) - else: - dkv[:, cu_seqlens_kv_padded[-1] :].fill_(0) + dk[cu_seqlens_kv_padded[-1] :].fill_(0) + dv[cu_seqlens_kv_padded[-1] :].fill_(0) if ctx.fp8 and ctx.is_input_fp8: - assert torch.uint8 not in [dq.dtype, dkv.dtype] - if ctx.enable_mla: - dq, dk, dv = [ctx.dQKV_quantizer(x)._data for x in [dq, dk, dv]] - else: - dq, dkv = [ctx.dQKV_quantizer(x)._data for x in [dq, dkv]] - if not ctx.enable_mla: - dk, dv = dkv[0], dkv[1] + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + + if ctx.fp8: + # print quantizers + print_quantizers( + "AttnFuncWithCPAndKVP2P.backward >> after: ", + ctx.layer_number, + ctx.QKV_quantizer, + ctx.O_quantizer, + ctx.S_quantizer, + ctx.dQKV_quantizer, + ctx.dO_quantizer, + ctx.dP_quantizer, + ) if cp_size_a2a > 1: + if ctx.fp8 and ctx.is_input_fp8: + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = (dq_fp8._data, dk_fp8._data, dv_fp8._data) chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size_a2a, q.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], @@ -2694,20 +2515,21 @@ def backward(ctx, dout): ctx.cp_stream, False, ) + if ctx.fp8 and ctx.is_input_fp8: + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] if ctx.qkv_format == "bshd": dq, dk, dv = [x.view(ctx.batch_size, -1, *x.shape[-2:]) for x in [dq, dk, dv]] elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] if 
attn_dbias is not None: - # [b, np, sq, 2*cp, sk//(2*cp)] -> [b, np, sq, sk] + # [b, h, sq, 2*cp, sk//(2*cp)] -> [b, h, sq, sk] attn_dbias = attn_dbias.view(*attn_dbias.shape[:-2], -1) - # converting torch.uint8 to float8tensor - if ctx.fp8 and ctx.is_input_fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, fake_dtype=dout_dtype) - dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, fake_dtype=dout_dtype) - dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) - nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") + + nvtx_range_pop(f"{nvtx_label}") return ( None, @@ -2736,6 +2558,9 @@ def backward(ctx, dout): None, None, None, + None, + None, + None, ) @@ -2791,6 +2616,7 @@ def forward( attn_bias, deterministic, use_fused_attention, + return_max_logit, window_size, cp_group, cp_stream, @@ -2865,22 +2691,22 @@ def forward( else: cu_seqlens_q_padded = None - # [b, s, np, hn] -> [b, 2, s//2, np, hn] or [s, b, np, hn] -> [2, s//2, b, np, hn] + # [b, s, h, d] -> [b, 2, s//2, h, d] or [s, b, h, d] -> [2, s//2, b, h, d] q = q.view(*q.shape[:seq_dim], 2, q.shape[seq_dim] // 2, *q.shape[(seq_dim + 1) :]) - # [b, s, np, hn] or [s, b, np, hn] -> [s, b, np, hn] + # [b, s, h, d] or [s, b, h, d] -> [s, b, h, d] k, v = [x.movedim(seq_dim, 0).contiguous() for x in [k, v]] - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, cp_group) v_ag, _ = gather_along_first_dim(v, cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, 
h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) cp_stream.wait_stream(torch.cuda.current_stream()) @@ -2896,12 +2722,14 @@ def forward( softmax_lse_per_step = [None, None] rng_states = [None, None] out = torch.empty_like(q) + max_logit_per_step = [None, None] + max_logit = None for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() kv_seq_range_per_step[i], window_size_per_step[i] = ( get_kv_seq_info_after_all_gather( @@ -2923,10 +2751,14 @@ def forward( k.shape[1], max_seqlen_kv_, k.device ) k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [s_range, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [s_range, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] if use_fused_attention: - out_per_step[i], [softmax_lse_per_step[i], rng_states[i]] = fused_attn_fwd( + ( + out_per_step[i], + [softmax_lse_per_step[i], rng_states[i]], + *max_logit_, + ) = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv_, @@ -2946,7 +2778,11 @@ def forward( cu_seqlens_q_padded=cu_seqlens_q_padded, cu_seqlens_kv_padded=cu_seqlens_kv_per_step[i], window_size=window_size_per_step[i], + return_max_logit=return_max_logit, + cuda_graph=is_graph_capturing(), ) + if return_max_logit: + max_logit_per_step[i] = max_logit_[0] else: fa_forward_args_thd = get_fa_args( True, @@ -2981,14 +2817,22 @@ def forward( if not use_flash_attn_3: rng_states[i] = fa_outputs[3] + if return_max_logit and i == 0: + max_logit = torch.clone(max_logit_per_step[0]) if i > 0: with torch.cuda.stream(flash_attn_streams[i - 1]): if qkv_format == "bshd": out[:, i - 
1].copy_(out_per_step[i - 1]) elif qkv_format == "sbhd": out[i - 1].copy_(out_per_step[i - 1]) + if return_max_logit: + max_logit = torch.maximum(max_logit, max_logit_per_step[i - 1]) torch.cuda.current_stream().wait_stream(cp_stream) + if return_max_logit: + torch.distributed.all_reduce( + max_logit, op=torch.distributed.ReduceOp.MAX, group=cp_group + ) if use_fused_attention: if qkv_format == "bshd": @@ -3025,10 +2869,12 @@ def forward( ctx.use_fused_attention = use_fused_attention ctx.use_flash_attn_3 = use_flash_attn_3 nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") + if return_max_logit: + return out, max_logit return out @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *_args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) @@ -3059,17 +2905,17 @@ def backward(ctx, dout): # synchronize dkv update across steps dkv_update_done = torch.cuda.Event() - # [s, b, np, hn] -> [cp, s, b, np, hn] + # [s, b, h, d] -> [cp, s, b, h, d] k_ag, _ = gather_along_first_dim(k, ctx.cp_group) v_ag, _ = gather_along_first_dim(v, ctx.cp_group) - # [cp, s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp, s, b, h, d] -> [cp*2, s//2, b, h, d] k_ag = k_ag.view(2 * cp_size, k.shape[0] // 2, *k.shape[1:]) v_ag = v_ag.view(2 * cp_size, v.shape[0] // 2, *v.shape[1:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_before_attn(cp_size, k.device) k_ag = torch.index_select(k_ag, dim=0, index=chunk_ids_for_kv_ag) v_ag = torch.index_select(v_ag, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] k_ag = k_ag.view(-1, *k.shape[1:]) v_ag = v_ag.view(-1, *v.shape[1:]) ctx.cp_stream.wait_stream(torch.cuda.current_stream()) @@ -3110,8 +2956,8 @@ def backward(ctx, dout): for i in range(len(local_seq_chunk_ids) + 1): if i < len(local_seq_chunk_ids): with 
torch.cuda.stream(flash_attn_streams[i]): - # [b, 2, sq//2, np, hn] -> [b, sq//2, np, hn] - # or [2, sq//2, b, np, hn] -> [sq//2, b, np, hn] + # [b, 2, sq//2, h, d] -> [b, sq//2, h, d] + # or [2, sq//2, b, h, d] -> [sq//2, b, h, d] q_ = q.select(seq_dim, i).contiguous() seq_start_idx, seq_end_idx = ( kv_seq_range_per_step[i][0], @@ -3119,13 +2965,13 @@ def backward(ctx, dout): ) max_seqlen_kv = seq_end_idx - seq_start_idx k_, v_ = [x[seq_start_idx:seq_end_idx] for x in [k_ag, v_ag]] - # [cp*s, b, np, hn] -> [b, s_range, np, hn] or [s_range, b, np, hn] + # [cp*s, b, h, d] -> [b, s_range, h, d] or [s_range, b, h, d] k_, v_ = [x.movedim(0, seq_dim).contiguous() for x in [k_, v_]] out_ = out_per_step[i] dout_ = dout.select(seq_dim, i).contiguous().view(out_.shape) if ctx.use_fused_attention: aux_ctx_tensors = [softmax_lse_per_step[i], rng_states[i]] - dq_per_step[i], dk_per_step[i], dv_per_step[i], _ = fused_attn_bwd( + dq_per_step[i], dk_per_step[i], dv_per_step[i], *_ = fused_attn_bwd( ctx.max_seqlen_q, max_seqlen_kv, cu_seqlens_q, @@ -3148,6 +2994,7 @@ def backward(ctx, dout): attn_bias_type=ctx.attn_bias_type, window_size=window_size_per_step[i], deterministic=ctx.deterministic, + cuda_graph=is_graph_capturing(), ) else: dq_per_step[i], dk_per_step[i], dv_per_step[i] = [ @@ -3192,7 +3039,7 @@ def backward(ctx, dout): dq[:, i - 1].copy_(dq_per_step[i - 1]) elif ctx.qkv_format == "sbhd": dq[i - 1].copy_(dq_per_step[i - 1]) - # [b, s_range, np, hn] or [s_range, b, np, hn] -> [s_range, b, np, hn] + # [b, s_range, h, d] or [s_range, b, h, d] -> [s_range, b, h, d] dk_per_step[i - 1], dv_per_step[i - 1] = [ x.movedim(seq_dim, 0).contiguous() for x in [dk_per_step[i - 1], dv_per_step[i - 1]] @@ -3211,13 +3058,13 @@ def backward(ctx, dout): torch.cuda.current_stream().wait_stream(ctx.cp_stream) - # [cp*s, b, np, hn] -> [cp*2, s//2, b, np, hn] + # [cp*s, b, h, d] -> [cp*2, s//2, b, h, d] dk = dk.view(2 * cp_size, -1, *dk.shape[-3:]) dv = dv.view(2 * cp_size, -1, 
*dv.shape[-3:]) chunk_ids_for_kv_ag = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dk.device) dk = torch.index_select(dk, dim=0, index=chunk_ids_for_kv_ag) dv = torch.index_select(dv, dim=0, index=chunk_ids_for_kv_ag) - # [cp*2, s//2, b, np, hn] -> [cp*s, b, np, hn] + # [cp*2, s//2, b, h, d] -> [cp*s, b, h, d] dk = dk.view(-1, *dk.shape[-3:]) dv = dv.view(-1, *dv.shape[-3:]) dk, _ = reduce_scatter_along_first_dim(dk, ctx.cp_group) @@ -3249,6 +3096,7 @@ def backward(ctx, dout): None, None, None, + None, ) @@ -3279,6 +3127,7 @@ def forward( attn_bias, deterministic, use_fused_attention, + return_max_logit, window_size, fp8, fp8_meta, @@ -3286,6 +3135,9 @@ def forward( cp_stream, quantizers, use_flash_attn_3, + softmax_type, + softmax_offset, + fp8_output, ): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") @@ -3293,7 +3145,6 @@ def forward( softmax_scale = q.shape[-1] ** (-0.5) cp_size = get_distributed_world_size(cp_group) - qkv_dtype = q.dtype causal = "causal" in attn_mask_type padding = "padding" in attn_mask_type @@ -3357,32 +3208,38 @@ def forward( q.shape[seq_dim] % 2 == 0 and k.shape[seq_dim] % 2 == 0 ), "Sequence length per GPU needs to be divisible by 2!" + assert isinstance(k, q.__class__) and isinstance( + v, q.__class__ + ), "q, k, v must be of the same class, e.g. torch.Tensor or Float8Tensor." 
+ is_input_fp8 = isinstance(q, Float8Tensor) + is_output_fp8 = fp8_output + is_bwd_fp8 = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # recipe passed in through autocast or set by NVTE_DPA_FP8_RECIPE; + # may be different from fp8_meta["recipe"] + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + fwd_nominal_dtype = q.dtype fused_attn_backend = None - # "fp8_mha" decides outputs in fp8, while inputs are inferred from the real dtype - is_input_fp8 = False - is_output_fp8 = False + max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + + q_fp8, k_fp8, v_fp8 = (None, None, None) if fp8: if use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - assert isinstance(k, q.__class__) and isinstance( - v, q.__class__ - ), "q, k, and v must have the same type." - is_input_fp8 = isinstance(q, Float8Tensor) - is_output_fp8 = fp8_meta is not None and fp8_meta["recipe"].fp8_mha if is_input_fp8: - QKV_quantizer = q._quantizer q_fp8, k_fp8, v_fp8 = q, k, v q, k, v = q_fp8._data, k_fp8._data, v_fp8._data - elif int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] + else: + q_fp8, k_fp8, v_fp8 = combine_and_quantize(qkv_layout, q, k, v, QKV_quantizer) + q, k, v = [q_fp8._data, k_fp8._data, v_fp8._data] fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = S_quantizer - fp8_meta_kwargs["o_quantizer"] = O_quantizer # partial result quantizer + fp8_meta_kwargs["o_quantizer"] = O_quantizer else: assert False, "FP8 is only supported with Fused Attention!" 
else: @@ -3394,25 +3251,23 @@ def forward( q, k, v = flash_attn_a2a_communicate( [q, k, v], chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, True ) + if softmax_type != "vanilla": + softmax_offset = flash_attn_a2a_communicate_softmax_offset( + softmax_offset, 1, cp_size, cp_group, cp_stream, True + ) - if fp8 and not is_input_fp8 and not int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_f16, k_f16, v_f16 = q, k, v - q, k, v = [QKV_quantizer(x)._data for x in [q_f16, k_f16, v_f16]] - + out_fp8 = None + out_f16 = None batch_size = q.shape[batch_dim] + q_part, k_part, v_part = q, k, v + out_part = None if use_fused_attention: - q_part, k_part, v_part = q, k, v if fp8: - q_part = QKV_quantizer.create_tensor_from_data( - q, fake_dtype=qkv_dtype, internal=True - ) - k_part = QKV_quantizer.create_tensor_from_data( - k, fake_dtype=qkv_dtype, internal=True - ) - v_part = QKV_quantizer.create_tensor_from_data( - v, fake_dtype=qkv_dtype, internal=True - ) - out, aux_ctx_tensors = fused_attn_fwd( + q_part, k_part, v_part = [ + Float8Tensor.make_like(x, data=y, dtype=fwd_nominal_dtype) + for x, y in zip([q_fp8, k_fp8, v_fp8], [q_part, k_part, v_part]) + ] + out_, aux_ctx_tensors, *max_logit = fused_attn_fwd( is_training, max_seqlen_q, max_seqlen_kv, @@ -3421,7 +3276,7 @@ def forward( q_part, k_part, v_part, - qkv_dtype, + fwd_nominal_dtype, fused_attn_backend, attn_scale=softmax_scale, dropout=dropout_p, @@ -3433,9 +3288,29 @@ def forward( cu_seqlens_kv_padded=cu_seqlens_kv_padded, window_size=window_size, **fp8_meta_kwargs, + softmax_type=softmax_type, + softmax_offset=softmax_offset, + return_max_logit=return_max_logit, + cuda_graph=is_graph_capturing(), ) - if fp8: - out = out._data + if isinstance(out_, Float8Tensor): + out_fp8 = out_ + out_ = out_._data + if is_bwd_fp8 and not ( + fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 + ): + out_part = out_fp8 + else: + out_part = out_fp8.dequantize(dtype=fwd_nominal_dtype) + else: + out_f16 = out_ + out_part = 
out_ + if ( + fp8 + and is_bwd_fp8 + and not (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) + ): + out_part = O_quantizer(out_) else: fa_forward_args_thd = get_fa_args( True, @@ -3447,67 +3322,71 @@ def forward( max_seqlen_kv=max_seqlen_kv, ) fa_outputs = flash_attn_fwd( - q, - k, - v, + q_part, + k_part, + v_part, *fa_forward_args_thd, causal=causal, **fa_forward_kwargs, ) if not fa_utils.v2_7_0_plus: - out, softmax_lse = fa_outputs[4], fa_outputs[5] + out_, softmax_lse = fa_outputs[4], fa_outputs[5] rng_state = fa_outputs[7] if not use_flash_attn_3 else None else: - out, softmax_lse = fa_outputs[0], fa_outputs[1] + out_, softmax_lse = fa_outputs[0], fa_outputs[1] rng_state = fa_outputs[3] if not use_flash_attn_3 else None aux_ctx_tensors = [softmax_lse, rng_state] + out_part = out_ - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out.device) - out = flash_attn_a2a_communicate( - out, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, out_.device) + out_ = flash_attn_a2a_communicate( + out_, chunk_ids_for_a2a, seq_dim, cp_size, cp_group, cp_stream, False ) + if return_max_logit: + max_logit = flash_attn_a2a_communicate_softmax_offset( + *max_logit, 0, cp_size, cp_group, cp_stream, False + ) if use_fused_attention: if qkv_format == "bshd": - # [b*s, np, hn] -> [b, s, np, hn] - out = out.view(batch_size, -1, *out.shape[-2:]) + # [b*s, h, d] -> [b, s, h, d] + out_ = out_.view(batch_size, -1, *out_.shape[-2:]) elif qkv_format == "sbhd": - # [s*b, np, hn] -> [s, b, np, hn] - out = out.view(-1, batch_size, *out.shape[-2:]) + # [s*b, h, d] -> [s, b, h, d] + out_ = out_.view(-1, batch_size, *out_.shape[-2:]) - if fp8: - if is_output_fp8: - out_fp8 = O_quantizer.create_tensor_from_data( - out, fake_dtype=qkv_dtype, internal=False - ) - out_ret = out_fp8 - out = out_fp8._data - else: - out_fp8 = O_quantizer.create_tensor_from_data( - out, 
fake_dtype=qkv_dtype, internal=True - ) - out_f16 = out_fp8.dequantize(dtype=qkv_dtype) - out_ret = out_f16 + if fp8 and use_fused_attention: + if fp8_recipe.float8_current_scaling(): + out_f16 = out_ + if is_output_fp8: + out_fp8 = O_quantizer(out_) + if fp8_recipe.delayed(): + out_fp8 = Float8Tensor.make_like(out_fp8, data=out_, dtype=fwd_nominal_dtype) + if not is_output_fp8: + out_f16 = out_fp8.dequantize(dtype=fwd_nominal_dtype) else: - out_ret = out + out_f16 = out_ - if not fp8 or int(os.getenv("NVTE_FP8_DPA_BWD", "1")): - q_save, k_save, v_save, out_save = q, k, v, out - else: - if is_input_fp8: - q_save, k_save, v_save = q, k, v - else: - q_save, k_save, v_save = q_f16, k_f16, v_f16 - if is_output_fp8: - out_save = out + out_ret = out_fp8 if is_output_fp8 else out_f16 + + ctx.fp8 = fp8 and is_bwd_fp8 + fp8_tensors = (None, None, None, None) + f16_tensors = (None, None, None, None) + if ctx.fp8: + if fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + fp8_tensors = (q_part, k_part, v_part, None) + f16_tensors = (None, None, None, out_part) else: - out_save = out_f16 + fp8_tensors = (q_part, k_part, v_part, out_part) + elif fp8: + q_part, k_part, v_part = combine_and_dequantize(qkv_layout, q_part, k_part, v_part) + f16_tensors = (q_part, k_part, v_part, out_part) + else: + f16_tensors = (q_part, k_part, v_part, out_part) tensors_to_save, tensor_objects = prepare_for_saving( - q_save, - k_save, - v_save, - out_save, + *fp8_tensors, + *f16_tensors, cu_seqlens_q, cu_seqlens_kv, cu_seqlens_q_padded, @@ -3516,6 +3395,7 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects + ctx.out_shape = out_ret.shape ctx.batch_size = batch_size ctx.cp_group = cp_group @@ -3530,13 +3410,14 @@ def forward( ctx.deterministic = deterministic ctx.window_size = window_size ctx.use_fused_attention = use_fused_attention - ctx.fp8 = fp8 and int(os.getenv("NVTE_FP8_DPA_BWD", "1")) ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 
ctx.is_output_fp8 = is_output_fp8 + ctx.fwd_nominal_dtype = fwd_nominal_dtype + ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.softmax_type = softmax_type - ctx.qkv_dtype = qkv_dtype ctx.dQKV_quantizer = dQKV_quantizer ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer @@ -3551,15 +3432,21 @@ def forward( ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") + if return_max_logit: + return out_ret, max_logit return out_ret @staticmethod - def backward(ctx, dout): + def backward(ctx, dout, *_args): # pylint: disable=missing-function-docstring nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) ( + q_fp8, + k_fp8, + v_fp8, + out_fp8, q, k, v, @@ -3570,23 +3457,21 @@ def backward(ctx, dout): cu_seqlens_kv_padded, *aux_ctx_tensors, ) = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - qkv_layout = ctx.qkv_format + "_" + ctx.qkv_format + "_" + ctx.qkv_format causal = "causal" in ctx.attn_mask_type seq_dim = ctx.qkv_format.index("s") - dout_dtype = dout.dtype + bwd_nominal_dtype = ctx.fwd_nominal_dtype + dqkv_te_dtype = None fused_attn_backend = None - fused_attn_dqkv_dtype = None + dout_fp8 = dout if ctx.fp8: if ctx.use_fused_attention: fused_attn_backend = FusedAttnBackend["FP8"] - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - else: + if not isinstance(dout, QuantizedTensorStorage): dout = ctx.dO_quantizer(dout) - fused_attn_dqkv_dtype = TE_DType[dout._data.dtype] + dout_fp8 = dout + dqkv_te_dtype = dout._fp8_dtype dout = dout._data fp8_meta_kwargs = {} fp8_meta_kwargs["s_quantizer"] = ctx.S_quantizer @@ -3596,44 +3481,23 @@ def backward(ctx, dout): else: assert False, "FP8 is only supported with Fused Attention!" 
else: - if ctx.fp8_meta is not None: - if ctx.is_output_fp8: - assert isinstance(dout, Float8Tensor), "dout must be Float8Tensors for FP8 MHA!" - ctx.dO_quantizer = dout._quantizer - dout = dout._data - if ctx.is_input_fp8: - q = ctx.QKV_quantizer.create_tensor_from_data( - q, fake_dtype=ctx.qkv_dtype, internal=True - ) - k = ctx.QKV_quantizer.create_tensor_from_data( - k, fake_dtype=ctx.qkv_dtype, internal=True - ) - v = ctx.QKV_quantizer.create_tensor_from_data( - v, fake_dtype=ctx.qkv_dtype, internal=True - ) - q, k, v = [x.dequantize(dtype=ctx.qkv_dtype) for x in [q, k, v]] + if isinstance(dout, QuantizedTensorStorage): + dout = dout.dequantize(dtype=bwd_nominal_dtype) if ctx.use_fused_attention: fp8_meta_kwargs = {} - fused_attn_dqkv_dtype = TE_DType[dout_dtype] + dqkv_te_dtype = TE_DType[dout.dtype] fused_attn_backend = FusedAttnBackend["F16_arbitrary_seqlen" if not IS_HIP_EXTENSION else "CK"] if not ctx.use_fused_attention: out = out.view(ctx.batch_size, -1, *out.shape[-2:]) - dout = dout.view(*out.shape) + dout = dout.view(ctx.batch_size, -1, *dout.shape[-2:]) + else: + dout = dout.view(*ctx.out_shape) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, out.device) - out, dout = flash_attn_a2a_communicate( - [out, dout], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_before_attn(cp_size, dout.device) + dout = flash_attn_a2a_communicate( + dout, chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, True ) - if not ctx.fp8 and ctx.fp8_meta is not None and ctx.is_output_fp8: - out = ctx.O_quantizer.create_tensor_from_data( - out, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout = ctx.dO_quantizer.create_tensor_from_data( - dout, fake_dtype=dout_dtype, internal=True - ) - out = out.dequantize(dtype=ctx.qkv_dtype) - dout = dout.dequantize(dtype=dout_dtype) flash_attn_bwd = None if not ctx.use_fused_attention: @@ -3674,31 +3538,15 @@ def 
backward(ctx, dout): if fa_utils.v2_6_0_plus: fa_backward_kwargs["softcap"] = 0.0 + dq_fp8, dk_fp8, dv_fp8 = None, None, None if ctx.use_fused_attention: - q_part = q - k_part = k - v_part = v - out_part = out - dout_part = dout - + q_part, k_part, v_part, out_part, dout_part = q, k, v, out, dout if ctx.fp8: - q_part = ctx.QKV_quantizer.create_tensor_from_data( - q_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - k_part = ctx.QKV_quantizer.create_tensor_from_data( - k_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - v_part = ctx.QKV_quantizer.create_tensor_from_data( - v_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - out_part = ctx.O_quantizer.create_tensor_from_data( - out_part, fake_dtype=ctx.qkv_dtype, internal=True - ) - dout_part = ctx.dO_quantizer.create_tensor_from_data( - dout_part, fake_dtype=dout_dtype, internal=True - ) - - dq, dk, dv, _ = fused_attn_bwd( + q_part, k_part, v_part, out_part = q_fp8, k_fp8, v_fp8, out_fp8 + if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + out_part = out + dout_part = Float8Tensor.make_like(dout_fp8, data=dout, dtype=bwd_nominal_dtype) + dq, dk, dv, *rest = fused_attn_bwd( ctx.max_seqlen_q, ctx.max_seqlen_kv, cu_seqlens_q, @@ -3708,8 +3556,8 @@ def backward(ctx, dout): v_part, out_part, dout_part, - dout_dtype, - fused_attn_dqkv_dtype, + bwd_nominal_dtype, + dqkv_te_dtype, aux_ctx_tensors, fused_attn_backend, cu_seqlens_q_padded=cu_seqlens_q_padded, @@ -3721,12 +3569,13 @@ def backward(ctx, dout): attn_bias_type=ctx.attn_bias_type, window_size=ctx.window_size, deterministic=ctx.deterministic, + cuda_graph=is_graph_capturing(), **fp8_meta_kwargs, + softmax_type=ctx.softmax_type, ) - if ctx.fp8: - dq = dq._data - dk = dk._data - dv = dv._data + if isinstance(dq, Float8Tensor): + dq_fp8, dk_fp8, dv_fp8 = dq, dk, dv + dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] @@ -3756,7 +3605,7 @@ def 
backward(ctx, dout): **fa_backward_kwargs, ) - chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, q.device) + chunk_ids_for_a2a = get_seq_chunk_ids_for_reordering_after_attn(cp_size, dq.device) dq, dk, dv = flash_attn_a2a_communicate( [dq, dk, dv], chunk_ids_for_a2a, seq_dim, cp_size, ctx.cp_group, ctx.cp_stream, False ) @@ -3766,18 +3615,34 @@ def backward(ctx, dout): elif ctx.qkv_format == "sbhd": dq, dk, dv = [x.view(-1, ctx.batch_size, *x.shape[-2:]) for x in [dq, dk, dv]] + d_bias = None + d_softmax_offset = None + if ctx.use_fused_attention: + if ctx.attn_bias_type not in ["no_bias", "alibi"]: + d_bias = rest[0] + if ctx.softmax_type != "vanilla": + d_softmax_offset = rest[1] + d_softmax_offset = flash_attn_a2a_communicate_softmax_offset( + d_softmax_offset, 1, cp_size, ctx.cp_group, ctx.cp_stream, False + ) + if ctx.fp8: - dq = ctx.dQKV_quantizer.create_tensor_from_data( - dq, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dk = ctx.dQKV_quantizer.create_tensor_from_data( - dk, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - dv = ctx.dQKV_quantizer.create_tensor_from_data( - dv, fake_dtype=dout_dtype, internal=not ctx.is_input_fp8 - ) - if not ctx.is_input_fp8: - dq, dk, dv = [x.dequantize(dtype=dout_dtype) for x in [dq, dk, dv]] + if ctx.fp8_recipe.float8_current_scaling() and ctx.is_input_fp8: + dq, dk, dv = combine_and_quantize(qkv_layout, dq, dk, dv, ctx.dQKV_quantizer) + if ctx.fp8_recipe.delayed(): + dq, dk, dv = [ + Float8Tensor.make_like(x, data=y, dtype=bwd_nominal_dtype) + for x, y in zip([dq_fp8, dk_fp8, dv_fp8], [dq, dk, dv]) + ] + if not ctx.is_input_fp8: + dq, dk, dv = combine_and_dequantize( + qkv_layout, + dq, + dk, + dv, + src_nominal_dtype=bwd_nominal_dtype, + ) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( @@ -3796,6 +3661,8 @@ def backward(ctx, dout): None, None, None, + d_bias, + None, None, None, None, @@ -3806,6 +3673,8 @@ def backward(ctx, dout): None, 
None, None, + d_softmax_offset, + None, ) @@ -3838,6 +3707,11 @@ def attn_forward_func_with_cp( quantizers=None, pad_between_seqs=False, use_flash_attn_3=False, + softmax_type="vanilla", + softmax_offset=None, + fp8_output=False, + layer_number=1, + return_max_logit=False, ) -> torch.Tensor: """ Attention implementation with context parallelism (CP). CP partitions tensors along the sequence @@ -3901,10 +3775,15 @@ def attn_forward_func_with_cp( """ if cp_comm_type == "a2a+p2p": - assert isinstance( - cp_group, list - ), "Hierarchical CP implementation needs multi-level CP groups!" - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" + assert ( + isinstance(cp_group, list) and len(cp_group) == 2 + ), "CP implementation a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" + assert ( + qkv_format != "thd" + ), f"{qkv_format} format is not supported with hierarchical CP implementation yet!" + assert ( + attn_bias_type == "no_bias" + ), f"{attn_bias_type} bias type is not supported with hierarchical CP implementation yet!" if get_distributed_world_size(cp_group[0]) == 1: cp_group = cp_group[1] cp_comm_type = "p2p" @@ -3914,23 +3793,23 @@ def attn_forward_func_with_cp( else: assert isinstance( cp_group, dist_group_type - ), f"Unsupported process group for CP communication type {cp_comm_type}!" + ), f"cp_group must be {dist_group_type} type for {cp_comm_type=}!" assert qkv_format in [ "bshd", "sbhd", "thd", - ], f"QKV format of {qkv_format} is not supported with context parallelism!" + ], f"Context parallelism does not support {qkv_format=}!" assert ( qkv_format != "sbhd" or use_fused_attention - ), "FlashAttention does not support sbhd format!" + ), "Context parallelism does not support FlashAttention backend with qkv_format = 'sbhd'!" 
assert attn_bias is None or (use_fused_attention and "padding" not in attn_mask_type), ( - """Attention bias is only supported with FusedAttention and "causal" """ - """or "no_mask" mask types!""" + "Context parallelism only supports attention bias with FusedAttention backend and" + " non-padding mask types!" ) assert qkv_format != "thd" or ( cu_seqlens_q_padded is not None and cu_seqlens_kv_padded is not None - ), "cu_seqlens_padded cannot be None with context parallelism + THD format!" + ), "cu_seqlens_padded can not be None for context parallelism and qkv_format = 'thd'!" sliding_window_attn = ( window_size is not None and window_size != (-1, 0) and window_size != (-1, -1) @@ -3938,13 +3817,28 @@ def attn_forward_func_with_cp( assert not sliding_window_attn or cp_comm_type in [ "a2a", "all_gather", - ], "The context parallel running configs cannot support sliding window attetnion!" + ], "Context parallelism does not support sliding window attention with {cp_comm_type=}!" enable_mla = k.shape[-1] != v.shape[-1] assert not enable_mla or cp_comm_type in [ "p2p", "a2a+p2p", - ], "The context parallel running configs cannot support MLA!" + ], "Context parallelism does not support MLA with {cp_comm_type=}!" + + if fp8 and fp8_meta is not None: + if fp8_meta["recipe"].fp8_dpa: + assert ( + softmax_type == "vanilla" + ), "Context parallelism does not support {softmax_type=} with FP8 attention!" + assert ( + softmax_type == "vanilla" or use_fused_attention + ), "Context parallelism only supports {softmax_type=} with FusedAttention backend!" + assert ( + softmax_type == "vanilla" or cp_comm_type == "a2a" + ), "Context parallelism only supports {softmax_type=} with cp_comm_type = 'a2a'!" + assert ( + softmax_type == "vanilla" or qkv_format != "thd" + ), "Context parallelism does not support {softmax_type=} with qkv_format = 'thd'!" 
args = [ is_training, @@ -3965,6 +3859,7 @@ def attn_forward_func_with_cp( attn_bias, deterministic, use_fused_attention, + return_max_logit, ] if cp_comm_type in ["p2p", "a2a+p2p"]: @@ -3977,6 +3872,8 @@ def attn_forward_func_with_cp( quantizers, pad_between_seqs, use_flash_attn_3, + fp8_output, + layer_number, ] out = AttnFuncWithCPAndKVP2P.apply(*args) elif cp_comm_type == "all_gather": @@ -3985,7 +3882,18 @@ def attn_forward_func_with_cp( args += [window_size, cp_group, cp_stream, use_flash_attn_3] out = AttnFuncWithCPAndKVAllGather.apply(*args) elif cp_comm_type == "a2a": - args += [window_size, fp8, fp8_meta, cp_group, cp_stream, quantizers, use_flash_attn_3] + args += [ + window_size, + fp8, + fp8_meta, + cp_group, + cp_stream, + quantizers, + use_flash_attn_3, + softmax_type, + softmax_offset, + fp8_output, + ] out = AttnFuncWithCPAndQKVOA2A.apply(*args) else: raise ValueError(f"Unsupported communication type: {cp_comm_type}!") diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b35b87a83..ef601e4c4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -11,11 +11,26 @@ import logging import torch +from torch.nn.parameter import Parameter import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Format, + Recipe, + DelayedScaling, + Float8CurrentScaling, +) from transformer_engine.pytorch.utils import get_cudnn_version -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.quantization import ( + get_fp8_te_dtype, + FP8GlobalStateManager, + RecipeState, + DelayedScalingRecipeState, + MXFP8BlockScalingRecipeState, + Float8CurrentScalingRecipeState, + 
Float8BlockScalingRecipeState, +) +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.constants import ( @@ -72,6 +87,67 @@ "_alibi_bias_require_update": False, } +""" +This feature is **experimental** and subject to change. + +Some models may use different FP8 recipes for their linear layers and attention layers. To support this, +users can either use multiple, nested autocast() contexts to assign a distinct recipe for each layer, +or use a single autocast() for the non-attention layers and configure the recipe for the attention +layers as follows. + ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| Linear | Attention | Configuration | ++===================+===========+===================================================================================+ +| FP8DS/FP8CS/NVFP4 | FP16/BF16 | Pass FP8DS, FP8CS or NVFP4 to autocast(); | +| | | export NVTE_DPA_FP8_RECIPE="F16" | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8DS | Pass FP8DS to autocast(); | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8DS | Pass FP8CS to autocast(); | +| | | Attention FP8DS reuses the fp8_format, fp8_dpa, fp8_mha values from linear FP8CS; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8DS | Pass NVFP4 to 
autocast(); | +| | | Attention FP8DS reuses the fp8_dpa, fp8_mha values from linear NVFP4; | +| | | export NVTE_DPA_FP8_RECIPE="DelayedScaling" # switch to DS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8DS | FP8CS | Pass FP8DS to autocast(); | +| | | Attention uses FP8DS for S, dP tensors, and creates a new FP8CS recipe for QKV, O,| +| | | dO, dQKV tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8DS; | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| FP8CS | FP8CS | Pass FP8CS to autocast(); | +| | | Attention uses FP8CS for QKV, O, dO, dQKV tensors, and creates a new FP8DS recipe | +| | | for S, dP tensors based on fp8_format, fp8_dpa, fp8_mha from linear FP8CS and: | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +| NVFP4 | FP8CS | Pass NVFP4 to autocast(); | +| | | Attention creates a new FP8CS recipe for QKV, O, dO, dQKV, and a new FP8DS recipe | +| | | for S, dP, based on the fp8_dpa, fp8_mha values from linear NVFP4 and: | +| | | export NVTE_DPA_FP8_RECIPE="Float8CurrentScaling" # switch to CS | +| | | export NVTE_DPA_FP8_FORMAT="HYBRID" # or "E4M3", "E5M2" | +| | | export NVTE_DPA_FP8DS_AMAX_ALGO="most_recent" # or "max" | +| | | export NVTE_DPA_FP8DS_AMAX_HISTLEN=1 # or any other integer | +| | | export 
NVTE_DPA_FP8DS_REDUCE_AMAX=1 # or 0 | ++-------------------+-----------+-----------------------------------------------------------------------------------+ +""" +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +formats = {"HYBRID": Format.HYBRID, "E4M3": Format.E4M3, "E5M2": Format.E5M2} +_dpa_fp8_format = formats[os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID")] +_dpa_fp8ds_amax_algo = os.getenv("NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent") +_dpa_fp8ds_amax_histlen = int(os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1")) +_dpa_fp8ds_reduce_amax = os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") == "1" + + __all__ = ["DotProductAttention"] @@ -168,6 +244,23 @@ class DotProductAttention(TransformerEngineBaseModule): softmax_scale: Optional[float], default = `None` softmax scale for the attention scores. If `None`, defaults to `1.0/math.sqrt(kv_channels if isinstance(kv_channels, int) else kv_channels[0])`. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). + return_max_logit: Optional[bool], default = `False` + If true, returns the maximum attention score that can be used in a Muon optimizer to + rescale the Q and K projection weights (see `Muon is Scalable for LLM Training + `_). + max_logit = max(S), where S = mask(Q*K^T*softmax_scale + bias) in shape [b, h, s_q, s_kv], + and max_logit is in shape [h]. 
Parallelism parameters ---------------------- @@ -223,7 +316,11 @@ def __init__( cp_stream: torch.cuda.Stream = None, cp_comm_type: str = "p2p", softmax_scale: Optional[float] = None, + softmax_type: str = "vanilla", + return_max_logit: Optional[bool] = False, ) -> None: + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." super().__init__() self.logger = logging.getLogger("DotProductAttention") @@ -306,6 +403,21 @@ def __init__( self.attention_type = attention_type self.attention_dropout = attention_dropout + self.return_max_logit = return_max_logit + + self.softmax_type = softmax_type + if self.softmax_type == "vanilla": + self.softmax_offset = None + if self.softmax_type == "off-by-one": + self.softmax_offset = torch.zeros( + self.num_attention_heads // self.tp_size, device="cuda" + ) + if self.softmax_type == "learnable": + self.register_parameter( + "softmax_offset", + Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")), + get_rng_state_tracker=get_rng_state_tracker, + ) attn_kwargs = { "attention_dropout": attention_dropout, @@ -328,6 +440,8 @@ def __init__( layer_number=layer_number, deterministic=self.deterministic, **attn_kwargs, + softmax_type=self.softmax_type, + return_max_logit=self.return_max_logit, ) self.unfused_attention = UnfusedDotProductAttention( @@ -335,6 +449,8 @@ def __init__( attention_type=attention_type, **attn_kwargs, layer_number=layer_number, + softmax_type=self.softmax_type, + return_max_logit=self.return_max_logit, ) def remove_extra_states_check(self, incompatible_keys): # pylint: disable=unused-argument @@ -433,6 +549,234 @@ def set_context_parallel_group( self.cp_stream = cp_stream self.cp_comm_type = cp_comm_type + def init_fp8_metadata(self, num_gemms: int = 1) -> None: + """ + Override TransformerEngineBaseModule.init_fp8_metadata to allow for more flexible recipe support. + Initialize fp8 related metadata and tensors during fprop. 
+ """ + _original_recipe = self.fp8_meta.get("recipe", None) + + # global recipe set in autocast() + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + if fp8_recipe.custom(): + return + + # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to + # a different recipe than fp8_recipe. DPA.quantizers may be a mix of different quantizers as well. + # + # fp8_recipe | NVTE_DPA_FP8_RECIPE | self.fp8_meta["recipe"] | self.quantizers + # -------------------------------------------------------------------------------------------- + # DelayedScaling (DS) | unset | DS | all DS + # Float8CurrentScaling (CS) | unset | DS | CS for QKV, O, dO, dQKV; DS for S, dP + # x={DS, CS} | y | refer to row x=y | refer to row x=y + fp8_recipe_dpa = fp8_recipe + fp8_recipes = fp8_recipe + if _dpa_fp8_recipe == "F16": + # ignore the recipe from autocast, set fp8_dpa = False, fp8_mha = False + fp8_recipe.fp8_dpa = False + fp8_recipe.fp8_mha = False + elif fp8_recipe.float8_current_scaling() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "DelayedScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format; construct a DS recipe + fake_recipe = DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = fp8_recipe_dpa + elif fp8_recipe.delayed() and _dpa_fp8_recipe == "Float8CurrentScaling": + 
# reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe, and construct a CS+DS recipe + fake_recipes = [ + Float8CurrentScaling( + fp8_format=fp8_recipe.fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + fp8_recipe, + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + elif ( + fp8_recipe.float8_current_scaling() + and _dpa_fp8_recipe in ("", "Float8CurrentScaling") + and (fp8_recipe.fp8_dpa or fp8_recipe.fp8_mha) + ): + # use fp8_recipe for QKV, O, dO, dQKV, and construct a DS recipe for S, dP + # reuse fp8_format, fp8_dpa, fp8_mha from fp8_recipe + fake_recipe = DelayedScaling( + fp8_format=fp8_recipe.fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ) + fp8_recipe_dpa = fake_recipe + fp8_recipes = [fp8_recipe, fp8_recipe_dpa] + elif fp8_recipe.nvfp4() and _dpa_fp8_recipe == "Float8CurrentScaling": + # reuse fp8_dpa, fp8_mha from fp8_recipe but not fp8_format + # construct a CS recipe for QKV, O, dO, dQKV and a DS recipe for S, dP + fake_recipes = [ + Float8CurrentScaling( + fp8_format=_dpa_fp8_format, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + ), + DelayedScaling( + fp8_format=_dpa_fp8_format, + amax_history_len=_dpa_fp8ds_amax_histlen, + amax_compute_algo=_dpa_fp8ds_amax_algo, + fp8_dpa=fp8_recipe.fp8_dpa, + fp8_mha=fp8_recipe.fp8_mha, + reduce_amax=_dpa_fp8ds_reduce_amax, + ), + ] + fp8_recipe_dpa = fake_recipes[1] + fp8_recipes = fake_recipes + # DPA only support DS and CS; other recipes should have fp8_dpa=False, fp8_mha=False + if not fp8_recipe_dpa.float8_per_tensor_scaling(): + assert not ( + fp8_recipe_dpa.fp8_dpa or fp8_recipe_dpa.fp8_mha + ), f"DotProductAttention does not support {fp8_recipe_dpa.__class__.__name__} recipe" + + # reduce over TP+CP groups; expect fp8_group to be set up so + # assume attention uses the same fp8_group as GEMMs + 
fp8_group = FP8GlobalStateManager.get_fp8_group() + + self.fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + self.fp8 = FP8GlobalStateManager.is_fp8_enabled() + self.fp8_calibration = FP8GlobalStateManager.is_fp8_calibration() + fp8_enabled = self.fp8 or self.fp8_calibration + self.fp8_meta["fp8_checkpoint"] = self.fp8 or self.fp8_calibration + if self.fp8_parameters or fp8_enabled: + self.fp8_meta["global_recipe"] = fp8_recipe + self.fp8_meta["local_recipes"] = ( + fp8_recipes if isinstance(fp8_recipes, List) else [fp8_recipes] + ) + + if self.fp8_parameters or fp8_enabled: + if self.fp8_initialized and fp8_recipe_dpa == self.fp8_meta["recipe"]: + # FP8 init has already been run and recipe is the same, don't do anything. + return + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. + autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + else: + # If fp8 isn't enabled, turn off and return. + self.fp8_initialized = False + return + + if self.fp8_parameters and not self.fp8_initialized: + self.fp8_meta["num_gemms"] = num_gemms + self.init_fp8_meta_tensors(fp8_recipes) + + if fp8_enabled: + # Set FP8 and other FP8 metadata + self.fp8_meta["num_gemms"] = num_gemms + self.fp8_meta["fp8_group"] = fp8_group + + # Set FP8_MAX per tensor according to recipe + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + + # Allocate scales and amaxes + self.init_fp8_meta_tensors(fp8_recipes) + self.fp8_initialized = True + + self.fp8_meta["recipe"] = fp8_recipe_dpa + if fp8_recipe != fp8_recipe_dpa: + # fp8_recipe has changed, rehash the key. 
+ autocast_key = FP8GlobalStateManager.get_unique_autocast_key( + fp8_recipe_dpa, fp8_group + ) + FP8GlobalStateManager.autocast_arguments[autocast_key] = ( + fp8_recipe_dpa, + fp8_group, + ) + + _current_recipe = self.fp8_meta["recipe"] + if _original_recipe is not None and not ( + issubclass(_current_recipe.__class__, _original_recipe.__class__) + or issubclass(_original_recipe.__class__, _current_recipe.__class__) + ): + warnings.warn( + f"Recipe type changed from {_original_recipe.__class__.__name__} " + f"to {_current_recipe.__class__.__name__}. " + "This may affect model behavior." + ) + # Clear cached workspaces as they were created with the old recipe/quantizer type + self._fp8_workspaces.clear() + + def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: + """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe): + recipe = [recipe] + fp8_recipe_dpa = recipe[-1] + fp8_meta_tensor_key = "scaling_fwd" if fwd else "scaling_bwd" + + # Return early if recipe state matches recipe + if self.fp8_meta_tensors_initialized: + recipe_state = self.fp8_meta[fp8_meta_tensor_key] + if fp8_recipe_dpa.delayed() and isinstance(recipe_state, DelayedScalingRecipeState): + self.adjust_amax_history_length(fp8_recipe_dpa.amax_history_len, fwd=fwd) + return + if fp8_recipe_dpa.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState): + return + if fp8_recipe_dpa.float8_current_scaling() and isinstance( + recipe_state, Float8CurrentScalingRecipeState + ): + return + if fp8_recipe_dpa.float8_block_scaling() and isinstance( + recipe_state, Float8BlockScalingRecipeState + ): + return + + # When fp8_recipe=Float8CurrentScaling, recipe=[CS, DS], and QKV/dQKV, O/dO use CS quantizers, S/dP use DS quantizers. + # See table above in init_fp8_metadata for more detail. + num_gemms = [2, 1] if len(recipe) == 2 else [3] + # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and + # 2 (grad_output and grad_input) for bwd + num_fp8_tensors = [x * 3 if fwd else x * 2 for x in num_gemms] + + # Initialize recipe state and quantizers + recipe_states = [ + RecipeState.create( + recipe[i], + mode=("forward" if fwd else "backward"), + num_quantizers=num_fp8_tensors[i], + ) + for i in range(len(recipe)) + ] + + self.fp8_meta[fp8_meta_tensor_key] = ( + recipe_states[-1] if len(recipe) == 2 else recipe_states[0] + ) + self.quantizers[fp8_meta_tensor_key] = [] + for recipe_state in recipe_states: + self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + @no_torch_dynamo(recursive=False) def forward( self, @@ -456,6 +800,7 @@ def forward( fast_zero_fill: bool = True, inference_params: Optional[InferenceParams] = None, pad_between_seqs: Optional[bool] = None, + fp8_output: Optional[bool] = False, ) -> torch.Tensor: """ Dot Product Attention Layer. @@ -628,12 +973,15 @@ def forward( pad_between_seqs: Optional[bool], default = `None` If None, inferred from qkv_format, cu_seqlens and cu_seqlens_padded. If true, there are padding tokens between individual sequences in a packed batch. + fp8_output: Optional[bool], default = `False` + Whether to enforce output to be in FP8 or not. 
""" with torch.cuda.device(query_layer.device), self.prepare_forward( query_layer, num_gemms=3, allow_non_contiguous=True, + allow_different_data_and_param_types=self.softmax_type != "vanilla", ) as query_layer: # checks for RNG if self.rng_states_tracker is not None and is_graph_capturing(): @@ -663,6 +1011,8 @@ def forward( tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2, ], """DotProductAttention only supports "E4M3" and "E5M2" FP8 data types.""" + else: + fp8_output = False # checks for q/k/v shapes assert ( @@ -922,6 +1272,7 @@ def forward( False ), "core_attention_bias must be in one of {bhss, 1hss, b1ss, 11ss} shapes" + # check if there is padding between sequences when qkv_format='thd' if pad_between_seqs is None: if qkv_format == "thd": pad_between_seqs = ( @@ -957,11 +1308,15 @@ def forward( pad_between_seqs=pad_between_seqs, attention_dropout=self.attention_dropout, context_parallel=context_parallel, + cp_comm_type=self.cp_comm_type, deterministic=self.deterministic, is_training=self.training, fp8=self.fp8, fp8_meta=self.fp8_meta, inference_params=inference_params, + softmax_type=self.softmax_type, + return_max_logit=self.return_max_logit, + cuda_graph=is_graph_capturing(), ) global _attention_backends if is_in_onnx_export_mode(): @@ -1022,6 +1377,12 @@ def forward( ) # run attention + softmax_offset = ( + self.softmax_offset.reshape(1, -1, 1, 1).to(torch.float32) + if self.softmax_offset is not None + else None + ) + if use_flash_attention: if core_attention_bias_type == "alibi": alibi_slopes, _ = dpa_utils.get_alibi( @@ -1053,6 +1414,7 @@ def forward( quantizers=self.quantizers, inference_params=inference_params, flash_attention_backend=flash_attention_backend, + fp8_output=fp8_output, ) if use_fused_attention: @@ -1071,7 +1433,6 @@ def forward( bias_dtype=query_layer.dtype, bottom_right_alignment=attn_mask_type not in ["causal", "padding_causal"], ) - # checkpoint_core_attention=False if checkpoint_core_attention: return 
self._checkpointed_attention_forward( self.fused_attention, @@ -1101,6 +1462,8 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8_output=fp8_output, ) return self.fused_attention( query_layer, @@ -1129,17 +1492,12 @@ def forward( quantizers=self.quantizers, pad_between_seqs=pad_between_seqs, inference_params=inference_params, - ) - - from transformer_engine.pytorch.cpu_offload import CPUOffloadEnabled - - if CPUOffloadEnabled: - warnings.warn( - "Attention activation Offloading is only implemented" - "with Flash Attention and Fused Attention!" + softmax_offset=softmax_offset, + fp8_output=fp8_output, ) if use_unfused_attention: + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" if checkpoint_core_attention: return self._checkpointed_attention_forward( self.unfused_attention, @@ -1150,6 +1508,8 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, @@ -1157,6 +1517,11 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return self.unfused_attention( _alibi_cache, @@ -1166,6 +1531,8 @@ def forward( qkv_layout=qkv_layout, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, + max_seqlen_q=max_seqlen_q, + max_seqlen_kv=max_seqlen_kv, attn_mask_type=attn_mask_type, attention_mask=attention_mask, window_size=window_size, @@ -1173,5 +1540,10 @@ def forward( core_attention_bias=core_attention_bias, alibi_slopes=alibi_slopes, inference_params=inference_params, + softmax_offset=softmax_offset, + fp8=self.fp8 and 
self.fp8_meta["recipe"].fp8_dpa and allow_emulation, + fp8_meta=self.fp8_meta, + quantizers=self.quantizers, + fp8_output=fp8_output, ) return None diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1677689c1..54fe21d81 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -19,6 +19,7 @@ from packaging.version import Version as PkgVersion import torch +import torch.distributed as dist import torch.nn.functional as F from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex @@ -27,6 +28,7 @@ QKVLayout, AttnBiasType, AttnMaskType, + SoftmaxType, FusedAttnBackend, META_QKV, META_DQKV, @@ -34,18 +36,22 @@ META_DO, META_S, META_DP, - META_O_CP, - META_DQKV_CP, ) from transformer_engine.pytorch.attention.inference import InferenceParams -from transformer_engine.pytorch.float8_tensor import Float8Tensor -from transformer_engine.pytorch.fp8 import get_fp8_te_dtype +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Tensor, + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.quantization import get_fp8_te_dtype from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.utils import ( get_device_compute_capability, get_cudnn_version, + SplitAlongDim, + combine_tensors, ) from transformer_engine.pytorch.export import is_in_onnx_export_mode @@ -56,6 +62,9 @@ # NVTE_DEBUG_LEVEL = 0/1/2 # enables more and more verbose debug mode, default = 0 _NVTE_DEBUG_LEVEL = int(os.getenv("NVTE_DEBUG_LEVEL", "0")) _NVTE_FLASH_ATTN = int(os.getenv("NVTE_FLASH_ATTN", "1")) +# print quantizer info for a particular layer on a particular rank +_print_layer = int(os.getenv("NVTE_PRINT_LAYER_NUMBER", "1")) +_print_rank = int(os.getenv("NVTE_PRINT_RANK", "0")) 
_cu_seqlens_cache = {} @@ -109,7 +118,7 @@ class FlashAttentionUtils: version = PkgVersion("0") version_required = PkgVersion("2.1.1") version_required_blackwell = PkgVersion("2.7.3") - max_version = PkgVersion("2.8.1") + max_version = PkgVersion("2.8.3") v2_plus = False v2_1_plus = False v2_3_plus = False @@ -209,16 +218,24 @@ class AttentionParams: Attention dropout. context_parallel: bool, default = `False` Whether context parallelism is used or not. + cp_comm_type: str, default = "p2p" + The communication type of context parallelism. deterministic: bool, default = `False` Whether to run `DotProductAttention` with determinism or not. is_training: bool, default = `True` Whether in training mode (`True`) or inference mode (`False`) fp8: bool, default = `False` - Whether `DotProductAttention` is in an `fp8_autocast` region. + Whether `DotProductAttention` is in an `autocast` region. fp8_meta: Optional[Dict[str Any]], default = `None` The FP8 metadata tensor of `DotProductAttention`. inference_params: Optional[InferenceParams], default = `None` Inference-related parameters. See InferenceParams for details. + softmax_type: str, default = "vanilla" + The type of softmax operation. See DotProductAttention for details. + return_max_logit: bool, default = `False` + Whether to output max_logit. + cuda_graph: bool, default = `False` + Whether support for cuda graph capture is needed or not. 
""" qkv_type: Union[torch.Tensor, Float8Tensor] = torch.Tensor @@ -240,11 +257,15 @@ class AttentionParams: pad_between_seqs: bool = False attention_dropout: float = 0.0 context_parallel: bool = False + cp_comm_type: str = "p2p" deterministic: bool = False is_training: bool = True fp8: bool = False fp8_meta: Union[Dict[str, Any], None] = None inference_params: Optional[InferenceParams] = None + softmax_type: str = "vanilla" + return_max_logit: bool = False + cuda_graph: bool = False def __eq__(self, other): """ @@ -313,11 +334,15 @@ def get_attention_backend( pad_between_seqs = attention_params.pad_between_seqs attention_dropout = attention_params.attention_dropout context_parallel = attention_params.context_parallel + cp_comm_type = attention_params.cp_comm_type deterministic = attention_params.deterministic is_training = attention_params.is_training fp8 = attention_params.fp8 fp8_meta = attention_params.fp8_meta inference_params = attention_params.inference_params + softmax_type = attention_params.softmax_type + return_max_logit = attention_params.return_max_logit + cuda_graph = attention_params.cuda_graph # Run config logger = logging.getLogger("DotProductAttention") @@ -346,8 +371,31 @@ def get_attention_backend( field.name: getattr(attention_params, field.name) for field in fields(attention_params) } run_config.update(attention_params_dict) + # Add FP8 environment variables to config if fp8: + # all FP8 recipes: 1: (FP8 fwd, FP8 bwd), 0: (FP8 fwd, F16 bwd) run_config["NVTE_FP8_DPA_BWD"] = int(os.getenv("NVTE_FP8_DPA_BWD", "1")) + # Float8CurrentScaling: 1: use F16 O in bwd, 0: use FP8 O in bwd + run_config["NVTE_DPA_FP8CS_O_in_F16"] = int(os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1")) + # switch recipe to "F16", "DelayedScaling", or "Float8CurrentScaling" + _dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") + run_config["NVTE_DPA_FP8_RECIPE"] = _dpa_fp8_recipe + if _dpa_fp8_recipe != "": + # config new recipe if switched + run_config["NVTE_DPA_FP8_FORMAT"] = 
os.getenv("NVTE_DPA_FP8_FORMAT", "HYBRID") + run_config["NVTE_DPA_FP8DS_AMAX_ALGO"] = os.getenv( + "NVTE_DPA_FP8DS_AMAX_ALGO", "most_recent" + ) + run_config["NVTE_DPA_FP8DS_AMAX_HISTLEN"] = int( + os.getenv("NVTE_DPA_FP8DS_AMAX_HISTLEN", "1") + ) + run_config["NVTE_DPA_FP8DS_REDUCE_AMAX"] = int( + os.getenv("NVTE_DPA_FP8DS_REDUCE_AMAX", "1") + ) + # UnfusedDotProductAttention: 1: allow FP8 emulation, 0: do not allow + run_config["NVTE_UnfusedDPA_Emulate_FP8"] = int( + os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") + ) logger.debug("Running with config=%s", run_config) # The following sections check if `FlashAttention` supports the provided attention params, @@ -427,8 +475,60 @@ def get_attention_backend( logger.debug("Disabling FlashAttention 3 for FP8 training") use_flash_attention_3 = False if use_unfused_attention: - logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") + allow_emulation = os.getenv("NVTE_UnfusedDPA_Emulate_FP8", "0") == "1" + if not allow_emulation: + logger.debug("Disabling UnfusedDotProductAttention for FP8 attention") + use_unfused_attention = False + fp8_recipe = fp8_meta["recipe"] + if fp8_meta.get("local_recipes", None) is not None: + fp8_recipe = fp8_meta["local_recipes"][0] + if use_fused_attention and fp8_recipe.float8_current_scaling() and not IS_HIP_EXTENSION: + if device_compute_capability < (10, 0): + logger.debug("Disabling FusedAttention for FP8 current scaling on arch < sm100") + use_fused_attention = False + # TODO(cyanguwa): Modify the min cuDNN version supporting FP8 current scaling + # determinism for Blackwell + else: + if cudnn_version < (9, 14, 0): + logger.debug( + "Disabling FusedAttention for FP8 current scaling with cuDNN < 9.14.0" + ) + use_fused_attention = False + else: + if deterministic and cudnn_version < (9, 18, 0): + logger.debug( + "Disabling FusedAttention for FP8 current scaling requiring determinism" + " with cuDNN < 9.18.0" + ) + use_fused_attention = False + + if 
device_compute_capability == (12, 0) and not IS_HIP_EXTENSION: + if use_flash_attention: + logger.debug( + "Disabling FlashAttention as FP8 is not supported" + " for compute capability = sm120" + ) + if use_fused_attention: + logger.debug( + "Disabling FusedAttention as FP8 is not supported" + " for compute capability = sm120" + ) + use_flash_attention = False + use_fused_attention = False + + # Filter: Return max_logit + if return_max_logit: + if use_flash_attention: + use_flash_attention = False + logger.debug("Disabling FlashAttention for max_logit") + if use_fused_attention and qkv_format == "thd": + use_fused_attention = False + logger.debug("Disabling FusedAttention for max_logit with qkv_format = thd") + if fp8 and fp8_meta["recipe"].fp8_dpa: + use_flash_attention = False + use_fused_attention = False use_unfused_attention = False + logger.debug("Disabling all backends for max_logit with FP8 attention") # Filter: KV cache # backend | precision | KV cache | architecture | qkv_format | page_size @@ -495,6 +595,21 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False + if ( + device_compute_capability == (12, 0) + and (head_dim_qk > 128 or head_dim_qk % 8 != 0) + and is_training + and not IS_HIP_EXTENSION + ): + if use_fused_attention: + logger.debug( + "Disabling FusedAttention as MLA for backward pass is not supported for compute" + " capability = sm120 for a head_dim_qk > 128 or head_dim_qk %%8 != 0. Found:" + " head_dim_qk = %s", + head_dim_qk, + ) + use_fused_attention = False + if use_flash_attention_2 and ( head_dim_qk > 256 or head_dim_qk % 8 != 0 @@ -565,12 +680,64 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. 
[a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False + if device_compute_capability == (12, 0) and not IS_HIP_EXTENSION: + if use_fused_attention: + logger.debug( + "Disabling FusedAttention as qkv_format = thd is" + " not supported for compute capability = sm120" + ) + use_fused_attention = False # Filter: Dropout if attention_dropout != 0.0 and use_flash_attention_3: logger.debug("Disabling FlashAttention 3 for dropout") use_flash_attention_3 = False + # Filter: Softmax type + # context_parallel | softmax_type | supported backends + # ---------------------------------------------------------------------------------------------------- + # no | vanilla | All + # no | off-by-one | FusedAttention, UnfusedDotProductAttention + # no | learnable | FusedAttention, UnfusedDotProductAttention + # yes | vanilla | FusedAttention, FlashAttention + # yes | off-by-one | FusedAttention + # yes | learnable | FusedAttention + if softmax_type != "vanilla": + logger.debug("Disabling FlashAttention for softmax_type = %s", softmax_type) + use_flash_attention = False + if fp8 and fp8_meta["recipe"].fp8_dpa: + logger.debug("Disabling FusedAttention for softmax_type = %s in FP8", softmax_type) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s in FP8", softmax_type + ) + use_unfused_attention = False + if qkv_format == "thd": + logger.debug( + "Disabling FusedAttention for softmax_type = %s and qkv_format = thd", softmax_type + ) + use_fused_attention = False + logger.debug( + "Disabling UnfusedDotProductAttention for softmax_type = %s and qkv_format = thd", + softmax_type, + ) + use_unfused_attention = False + if context_parallel: + logger.debug( + "Disabling UnfusedDotProductAttention for context parallelism with softmax_type" + " = %s", + softmax_type, + ) + use_unfused_attention = False + if cp_comm_type != "a2a": + logger.debug( + "Disabling FusedAttention for context parallelism with softmax_type = %s and" + " 
cp_comm_type = %s", + softmax_type, + cp_comm_type, + ) + use_fused_attention = False + # Filter: Context parallelism # qkv_format | attn_mask_type | attn_bias_type | supported backends # ---------------------------------------------------------------------------------------------------- @@ -814,6 +981,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt QKVLayout[qkv_layout], AttnBiasType[fu_core_attention_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], attention_dropout, num_heads, num_gqa_groups, @@ -823,6 +991,8 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt head_dim_v, window_size[0], window_size[1], + return_max_logit, + cuda_graph, ) if fused_attention_backend == FusedAttnBackend["No_Backend"]: logger.debug("Disabling FusedAttention as no backend supports the provided input") @@ -893,7 +1063,7 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") use_fused_attention = False fused_attention_backend = None - if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0): + if is_training and device_compute_capability >= (10, 0): logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") use_fused_attention = False fused_attention_backend = None @@ -1567,6 +1737,78 @@ def backward(ctx, grad_output): return None, None, _pack_tensor(indices, grad_output) +class ConvertTHDtoBSHD(torch.autograd.Function): + """ + Convert a tensor from qkv_format = thd to qkv_format = bshd. 
+ """ + + @staticmethod + def forward(ctx, thd_tensor, cu_seqlens, max_seqlen): + # pylint: disable=missing-function-docstring + batch_size = cu_seqlens.shape[0] - 1 + if not thd_tensor.is_contiguous(): + thd_tensor = thd_tensor.contiguous() + bshd_tensor = tex.convert_thd_to_bshd( + thd_tensor, + cu_seqlens, + batch_size, + max_seqlen, + ) + ctx.save_for_backward(cu_seqlens) + ctx.num_tokens = thd_tensor.shape[0] + return bshd_tensor + + @staticmethod + def backward(ctx, bshd_tensor): + # pylint: disable=missing-function-docstring + (cu_seqlens,) = ctx.saved_tensors + if not bshd_tensor.is_contiguous(): + bshd_tensor = bshd_tensor.contiguous() + thd_tensor = tex.convert_bshd_to_thd( + bshd_tensor, + cu_seqlens, + ctx.num_tokens, + ) + return thd_tensor, None, None + + +class ConvertBSHDtoTHD(torch.autograd.Function): + """ + Convert a tensor from qkv_format = bshd to qkv_format = thd. + """ + + @staticmethod + def forward(ctx, bshd_tensor, cu_seqlens): + # pylint: disable=missing-function-docstring + num_tokens = cu_seqlens[-1] + max_seqlen = bshd_tensor.shape[1] + if not bshd_tensor.is_contiguous(): + bshd_tensor = bshd_tensor.contiguous() + thd_tensor = tex.convert_bshd_to_thd( + bshd_tensor, + cu_seqlens, + num_tokens, + ) + ctx.save_for_backward(cu_seqlens) + ctx.max_seqlen = max_seqlen + return thd_tensor + + @staticmethod + def backward(ctx, thd_tensor): + # pylint: disable=missing-function-docstring + (cu_seqlens,) = ctx.saved_tensors + batch_size = cu_seqlens.shape[0] - 1 + if not thd_tensor.is_contiguous(): + thd_tensor = thd_tensor.contiguous() + bshd_tensor = tex.convert_thd_to_bshd( + thd_tensor, + cu_seqlens, + batch_size, + ctx.max_seqlen, + ) + return bshd_tensor, None + + def get_qkv_format( qkv_layout: str = "bshd_bshd_bshd", inference_params: InferenceParams = None, @@ -1836,11 +2078,10 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): +def get_attention_quantizers(fp8, 
quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: - num_of_nones = 8 if cp_specific_quantizers else 6 - return [None] * num_of_nones + return [None] * 6 QKV_quantizer = quantizers["scaling_fwd"][META_QKV] QKV_quantizer.internal = True QKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1849,6 +2090,7 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): S_quantizer = quantizers["scaling_fwd"][META_S] S_quantizer.internal = True S_quantizer.set_usage(rowwise=True, columnwise=False) + dQKV_quantizer = quantizers["scaling_bwd"][META_DQKV] dQKV_quantizer.interal = True dQKV_quantizer.set_usage(rowwise=True, columnwise=False) @@ -1858,22 +2100,158 @@ def get_attention_quantizers(fp8, quantizers, cp_specific_quantizers=False): dP_quantizer = quantizers["scaling_bwd"][META_DP] dP_quantizer.set_usage(rowwise=True, columnwise=False) dP_quantizer.interal = True - dQKV_CP_quantizer = quantizers["scaling_bwd"][META_DQKV_CP] - dQKV_CP_quantizer.set_usage(rowwise=True, columnwise=False) - dQKV_CP_quantizer.internal = True - O_CP_quantizer = quantizers["scaling_fwd"][META_O_CP] - O_CP_quantizer.set_usage(rowwise=True, columnwise=False) - - if cp_specific_quantizers: - return ( + + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + + +def print_quantizers( + label, + layer_number, + QKV_quantizer, + O_quantizer, + S_quantizer, + dQKV_quantizer, + dO_quantizer, + dP_quantizer, +): + """Print the type and scale/amax of attention quantizers""" + _to_print = _NVTE_DEBUG * _NVTE_DEBUG_LEVEL == 2 + if ( + _to_print + and _print_layer == layer_number + and ( + not dist.is_initialized() or (dist.is_initialized() and dist.get_rank() == _print_rank) + ) + ): + names = [ + "QKV_quantizer", + "S_quantizer", + "O_quantizer", + "dO_quantizer", + "dP_quantizer", + "dQKV_quantizer", + ] + quantizers = [ QKV_quantizer, - O_quantizer, - O_CP_quantizer, S_quantizer, - 
dQKV_quantizer, - dQKV_CP_quantizer, + O_quantizer, dO_quantizer, dP_quantizer, - ) + dQKV_quantizer, + ] + if "forward" in label: + names = names[:3] + quantizers = quantizers[:3] + if "backward" in label: + names = names[3:] + quantizers = quantizers[3:] + for i, q in enumerate(quantizers): + type_str = "" + if q is None: + type_str = "None" + elif isinstance(q, Float8Quantizer): + type_str = "DS" + elif isinstance(q, Float8CurrentScalingQuantizer): + type_str = "CS" + print( + f"{label} >> {names[i]:14s}: {type_str}, {q.scale.item():.4e} x" + f" {q.amax.item():.4e} = {q.scale.item()*q.amax.item():.4e}" + ) - return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer + +def combine_and_quantize(qkv_layout, q, k, v, qkv_quantizer): + """Combine q,k,v based on qkv_layout and quantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + src_nominal_dtype = q.dtype + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv = combine_tensors([q, k, v], dim) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = SplitAlongDim.apply(qkv_fp8._data, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv = combine_tensors([k, v], dim) + tensors = [q, kv] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, kv_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + k_data, v_data = SplitAlongDim.apply(kv_data, dim, [1, 1], True) + case 3: + tensors = [q, k, v] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv = 
torch.cat([x.view(-1) for x in tensors], dim=0) + qkv_fp8 = qkv_quantizer(qkv) + q_data, k_data, v_data = [ + qkv_fp8._data[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors) + ] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + + q_fp8, k_fp8, v_fp8 = [ + Float8Tensor.make_like(qkv_fp8, data=x, dtype=src_nominal_dtype) + for x in [q_data, k_data, v_data] + ] + + return q_fp8, k_fp8, v_fp8 + + +def combine_and_dequantize( + qkv_layout, q_fp8, k_fp8, v_fp8, src_nominal_dtype=None, des_nominal_dtype=None +): + """Combine q,k,v based on qkv_layout and dequantize them together""" + # 1: qkv packed, 2: kv packed, 3: qkv separate + qkv_layout = qkv_layout.replace("paged_kv_", "") + qkv_group = len(qkv_layout.split("_")) + if all(isinstance(x, Float8Tensor) for x in [q_fp8, k_fp8, v_fp8]): + src_nominal_dtype = q_fp8.dtype + else: + assert src_nominal_dtype is not None, "The nominal dtype of input tensors is required!" + if des_nominal_dtype is None: + des_nominal_dtype = src_nominal_dtype + + q_data, k_data, v_data = [x._data for x in [q_fp8, k_fp8, v_fp8]] + match qkv_group: + case 1: + dim = qkv_layout.find("3") + qkv_data = combine_tensors([q_data, k_data, v_data], dim) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = SplitAlongDim.apply(qkv, dim, [1, 1, 1], True) + case 2: + dim = qkv_layout.split("_")[1].find("2") + kv_data = combine_tensors([k_data, v_data], dim) + tensors = [q_data, kv_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, kv = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + k, v = 
SplitAlongDim.apply(kv, dim, [1, 1], True) + case 3: + tensors = [q_data, k_data, v_data] + num_tensors = len(tensors) + shapes = [x.shape for x in tensors] + numels = [x.numel() for x in tensors] + numels = [sum(numels[:i]) for i in range(num_tensors + 1)] + qkv_data = torch.cat([x.contiguous().reshape(-1) for x in tensors], dim=0) + qkv_fp8 = Float8Tensor.make_like(q_fp8, data=qkv_data, dtype=src_nominal_dtype) + qkv = qkv_fp8.dequantize(dtype=des_nominal_dtype) + q, k, v = [qkv[numels[i] : numels[i + 1]].view(shapes[i]) for i in range(num_tensors)] + case _: + raise RuntimeError("Invalid qkv_layout " + qkv_layout) + return q, k, v diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 5fd16bf1a..2440693df 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -3,13 +3,14 @@ # See LICENSE for license information. 
"""Multi-head Attention.""" +import os import collections from typing import Callable, List, Optional, Tuple, Union import torch from transformer_engine.debug.pytorch.debug_state import TEDebugState -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.float8_tensor import Float8Tensor +from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization @@ -31,7 +32,15 @@ from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.attention.rope import apply_rotary_pos_emb -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor + +from transformer_engine.pytorch.cpu_offload import start_offload, is_cpu_offload_enabled + +# Force DotProductAttention to use a different recipe than the fp8_recipe set in autocast(). +# Useful when GEMMs and attention use different recipes. Supported values are "DelayedScaling" +# and "Float8CurrentScaling". Use other relevant variables here to define the recipe, e.g. fp8_dpa. +_dpa_fp8_recipe = os.getenv("NVTE_DPA_FP8_RECIPE", "") +_dpa_fp8_recipe_dpa = os.getenv("NVTE_DPA_FP8_RECIPE_DPA", "0") == "1" +_dpa_fp8_recipe_mha = os.getenv("NVTE_DPA_FP8_RECIPE_MHA", "0") == "1" class MultiheadAttention(torch.nn.Module): @@ -135,6 +144,17 @@ class MultiheadAttention(torch.nn.Module): For that, please use `get_qkv_layout` to gain the layout information. name: str, default = `None` name of the module, currently used for debugging purposes. 
+ softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). Parallelism parameters ---------------------- @@ -245,6 +265,7 @@ def __init__( qk_norm_before_rope: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -262,6 +283,7 @@ def __init__( self.return_bias = return_bias self.cp_size = 1 self.cp_rank = 0 + self.softmax_type = softmax_type kv_channels = kv_channels if kv_channels else (hidden_size // num_attention_heads) @@ -416,6 +438,7 @@ def __init__( tp_group=tp_group, layer_number=self.layer_number, attention_type=self.attention_type, + softmax_type=self.softmax_type, ) # Linear @@ -556,10 +579,12 @@ def set_context_parallel_group( self.cp_size = get_distributed_world_size(cp_group) self.cp_rank = get_distributed_rank(cp_group) elif isinstance(cp_group, list): - assert len(cp_group) == 2, "Current implementation only supports two-level CP groups!" assert ( cp_comm_type == "a2a+p2p" ), "Only cp_comm_type of a2a+p2p requires hierarchical CP groups!" + assert ( + len(cp_group) == 2 + ), "cp_comm_type = a2a+p2p requires cp_group = [a2a_cp_group, p2p_cp_group]!" 
cp_size_a2a = get_distributed_world_size(cp_group[0]) cp_rank_a2a = get_distributed_rank(cp_group[0]) cp_size_p2p = get_distributed_world_size(cp_group[1]) @@ -716,10 +741,22 @@ def forward( # Query, Key, and Value # ====================== - fp8_mha = ( - FP8GlobalStateManager.is_fp8_enabled() - and FP8GlobalStateManager.get_fp8_recipe().fp8_mha - ) + fp8 = FP8GlobalStateManager.is_fp8_enabled() + if _dpa_fp8_recipe == "": + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + fp8_dpa = fp8_recipe.fp8_dpa + fp8_mha = fp8_recipe.fp8_mha + float8_current_scaling = fp8_recipe.float8_current_scaling() + else: + fp8_dpa = _dpa_fp8_recipe_dpa + fp8_mha = _dpa_fp8_recipe_mha + float8_current_scaling = _dpa_fp8_recipe == "Float8CurrentScaling" + # QKV Gemm: do not produce FP8 output when in Float8CurrentScaling recipe + qkv_fp8_output = fp8 and fp8_mha and rotary_pos_emb is None and not float8_current_scaling + # DPA: always produce FP8 output when fp8=True to take advantage of the O amax + dpa_fp8_output = fp8 and (fp8_dpa or fp8_mha) + # Proj Gemm: match DPA output except for Float8CurrentScaling + proj_fp8_grad = dpa_fp8_output and not float8_current_scaling layernorm_output = None if self.attention_type == "self": @@ -728,7 +765,7 @@ def forward( layernorm_qkv_outputs = self.layernorm_qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: mixed_x_layer, layernorm_output = layernorm_qkv_outputs @@ -738,7 +775,7 @@ def forward( mixed_x_layer = self.qkv( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) num_queries_per_key_value = ( @@ -792,7 +829,7 @@ def forward( mixed_kv_layer = self.key_value( encoder_output, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.qkv_weight_interleaved: @@ 
-847,7 +884,7 @@ def forward( layernorm_query_outputs = self.layernorm_query( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) if self.return_layernorm_output: query_layer, layernorm_output = layernorm_query_outputs @@ -857,7 +894,7 @@ def forward( query_layer = self.query_layer( hidden_states, is_first_microbatch=is_first_microbatch, - fp8_output=fp8_mha and rotary_pos_emb is None, + fp8_output=qkv_fp8_output, ) # [sq, b, hp] --> [sq, b, np, hn] @@ -936,7 +973,8 @@ def forward( # =========================== # Core attention computation # =========================== - + if is_cpu_offload_enabled(): + start_offload(query_layer, key_layer, value_layer, offload_base_tensor=True) context_layer = self.core_attention( query_layer, key_layer, @@ -958,6 +996,7 @@ def forward( fast_zero_fill=fast_zero_fill, inference_params=inference_params, pad_between_seqs=pad_between_seqs, + fp8_output=dpa_fp8_output, ) # =================== @@ -966,7 +1005,7 @@ def forward( projection_output = self.proj( context_layer, is_first_microbatch=is_first_microbatch, - fp8_grad=isinstance(context_layer, QuantizedTensor), + fp8_grad=proj_fp8_grad, ) if self.return_bias: diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 139381f2d..0e1222c22 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -66,6 +66,9 @@ def forward(self, max_seq_len: int, offset: int = 0): """ Create rotary position embedding frequencies. + This function is particularly sensitive to the use of mixed precision, so we disable the + autocast context if it is enabled. + Parameters ---------- max_seq_len: int @@ -73,26 +76,27 @@ def forward(self, max_seq_len: int, offset: int = 0): offset: int, default = 0 Fixed offset for frequencies. 
""" - seq = ( - torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) - + offset - ) + with torch.autocast(enabled=False, device_type="cuda"): + seq = ( + torch.arange(max_seq_len, device=self.inv_freq.device, dtype=self.inv_freq.dtype) + + offset + ) - if ( - self.pretrained_max_position_embeddings is not None - and self.seq_len_interpolation_factor is not None - ): if ( - max_seq_len - > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor + self.pretrained_max_position_embeddings is not None + and self.seq_len_interpolation_factor is not None ): - # dynamic linear scaling (length > position we have learned) - seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) - else: - # fixed linear scaling - seq *= 1 / self.seq_len_interpolation_factor - - freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) + if ( + max_seq_len + > self.pretrained_max_position_embeddings * self.seq_len_interpolation_factor + ): + # dynamic linear scaling (length > position we have learned) + seq *= 1 / (max_seq_len / self.pretrained_max_position_embeddings) + else: + # fixed linear scaling + seq *= 1 / self.seq_len_interpolation_factor + + freqs = torch.einsum("i , j -> i j", seq, self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size if not self.interleaved: @@ -145,7 +149,7 @@ def forward( cp_size, cp_rank, ) - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, start_positions) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank @@ -156,10 +160,11 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """Fused RoPE backward.""" - freqs, cu_seqlens = ctx.saved_tensors + freqs, cu_seqlens, start_positions = ctx.saved_tensors grad_input = tex.fused_rope_backward( grad_output, freqs, + start_positions, QKVFormat[ctx.tensor_format], 
ctx.interleaved, cu_seqlens, @@ -167,7 +172,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.cp_rank, ) - return grad_input, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None class FusedQKVRoPEFunc(torch.autograd.Function): @@ -274,7 +279,6 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: def _apply_rotary_pos_emb_base( t: torch.Tensor, freqs: torch.Tensor, - start_positions: torch.Tensor = None, tensor_format: str = "sbhd", interleaved: bool = False, ) -> torch.Tensor: @@ -287,45 +291,19 @@ def _apply_rotary_pos_emb_base( Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional embedding will be applied. freqs: torch.Tensor - Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', - with `s2 >= s` and `d2 <= d`. - start_positions: torch.Tensor, default = None. - Tokens in a sequence `i` should be applied with position encoding offset by - `start_positions[i]`. If `start_positions=None`, there's no offset. + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]` + and dtype 'float', with `s2 >= s` and `d2 <= d`. tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape `[seq, bs, ...]`. interleaved: bool, default = False Whether to use interleaved rotary position embedding. """ - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # In case `start_positions` are provided, create a staggered `freqs` tensor - # offset by the values in `start_positions`. - # `start_positions` is only supported for `cp_size=1` and inference. - if start_positions is not None: - max_offset = torch.max(start_positions) - assert ( - max_offset + cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!" 
- - # Stack staggered rope embeddings along the batch dimension - freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) - - # Note that from this point, `freqs` has a shape `(s,b,1,d)`. - - # Only apply the rotary embeddings up to the sequence length of the running - # input. - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" - freqs = freqs[:cur_seq_len] - # [seq, 1, 1, dim] -> [1, seq, 1, dim] or # [seq, b, 1, dim] -> [b, seq, 1, dim] if tensor_format == "bshd": freqs = freqs.transpose(0, 1) + # cos/sin first then dtype conversion for better precision cos_ = torch.cos(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype) @@ -362,7 +340,7 @@ def _get_freqs_on_this_cp_rank( ) # cp_size == 1 - return freqs + return freqs[:seqlen] def apply_rotary_pos_emb( @@ -384,13 +362,13 @@ def apply_rotary_pos_emb( Training: qkv_formats: "thd", "bshd", "sbhd" context parallel: yes - start_positions: no + start_positions: yes interleaving: yes Inference: qkv_formats: "thd", "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- @@ -419,22 +397,17 @@ def apply_rotary_pos_emb( cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ - - # `start_positions` is only supported for `cp_size=1` and inference. - assert not ( - cp_size > 1 and start_positions is not None - ), """start_positions != None with CP SIZE > 1 is not supported!""" - assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." 
+ # Fused apply rope logic for THD/BSHD/SBHD formats if fused: return FusedRoPEFunc.apply( t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank ) - # Unfused THD format + # Unfused apply rope logic for THD format if tensor_format == "thd": cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -443,15 +416,18 @@ def apply_rotary_pos_emb( # `s1hd` tensors (for each sequence) and applies rotary embedding to # those sequences individually. # Note that if `start_positions` is not `None`, then for each sequence, - # it's corresponding rope offset is also supplied from `start_positions` - # individually. + # the freqs supplied are offset by the corresponding `start_positions` value. return torch.cat( [ _apply_rotary_pos_emb_base( x.unsqueeze(1), - _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), - start_positions=( - start_positions[idx : idx + 1] if start_positions is not None else None + _get_freqs_on_this_cp_rank( + ( + freqs[start_positions[idx] :] if start_positions is not None else freqs + ), # offset the freqs + x.size(0), + cp_size, + cp_rank, ), interleaved=interleaved, ) @@ -459,17 +435,28 @@ def apply_rotary_pos_emb( ] ).squeeze(1) - # Unfused SBHD/BSHD format + # Unfused apply rope logic for SBHD/BSHD format follows ... + if tensor_format == "sbhd": seqlen = t.size(0) elif tensor_format == "bshd": seqlen = t.size(1) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + + if start_positions is not None: + max_offset = torch.max(start_positions) + assert ( + max_offset + seqlen * cp_size <= freqs.shape[0] + ), f"Rotary Embeddings only suppported up to {freqs.shape[0]} sequence length!" + + # Stack staggered rope embeddings along the batch dimension + freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1) + # Note that from this point, `freqs` has a shape `(s,b,1,d)`. 
+ return _apply_rotary_pos_emb_base( t, _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), - start_positions, tensor_format, interleaved=interleaved, ) @@ -501,7 +488,7 @@ def apply_fused_qkv_rotary_pos_emb( qkv_formats: "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- diff --git a/transformer_engine/pytorch/constants.py b/transformer_engine/pytorch/constants.py index 48dc1ba29..f51cf63a0 100644 --- a/transformer_engine/pytorch/constants.py +++ b/transformer_engine/pytorch/constants.py @@ -100,3 +100,5 @@ def __missing__(self, key): dist_group_type = torch.distributed.ProcessGroup MXFP8_BLOCK_SCALING_SIZE = 32 + +NVFP4_BLOCK_SCALING_SIZE = 16 diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 16fa9f3e8..e5492ebc6 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -15,9 +15,10 @@ NVTE_QKV_Format, NVTE_Bias_Type, NVTE_Mask_Type, + NVTE_Softmax_Type, NVTE_Fused_Attn_Backend, ) -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer __all__ = [ @@ -89,6 +90,12 @@ "padding_causal_bottom_right": NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, } +SoftmaxType = { + "vanilla": NVTE_Softmax_Type.NVTE_VANILLA_SOFTMAX, + "off-by-one": NVTE_Softmax_Type.NVTE_OFF_BY_ONE_SOFTMAX, + "learnable": NVTE_Softmax_Type.NVTE_LEARNABLE_SOFTMAX, +} + if not IS_HIP_EXTENSION: FusedAttnBackend = { "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, @@ -112,9 +119,6 @@ META_DO = tex.FP8BwdTensors.GRAD_INPUT2 META_S = tex.FP8FwdTensors.GEMM3_OUTPUT META_DP = tex.FP8BwdTensors.GRAD_INPUT3 -# repurpose some unused amax history buffers for partial results of CP fwd and bwd -META_O_CP = tex.FP8FwdTensors.GEMM2_OUTPUT -META_DQKV_CP = tex.FP8BwdTensors.GRAD_INPUT1 def fused_attn_fwd( is_training: 
bool, @@ -140,8 +144,12 @@ def fused_attn_fwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), rng_gen: torch.Generator = None, + softmax_offset: torch.Tensor = None, + return_max_logit: bool = False, + cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention FWD for separate QKV input. @@ -206,6 +214,8 @@ def fused_attn_fwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -214,6 +224,13 @@ def fused_attn_fwd( rng_gen: torch.Generator, default = None random number generator; if None, uses the default CUDA generator from PyTorch; otherwise, uses rng_gen + softmax_offset: torch.Tensor, default = None + softmax offset tensor in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. + return_max_logit: bool, default = False + whether to return the maximum attention score + cuda_graph: bool, default = False + whether or not cuda graph capture is enabled. Returns ---------- @@ -244,8 +261,13 @@ def fused_attn_fwd( rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen state of the random number generator; [seed, offset], dtype uint64 + max_logit: if return_max_logit = True, shape [h] and same data type as O; otherwise None """ + if IS_HIP_EXTENSION: + assert not return_max_logit, "ROCm does not support return_max_logit yet." + assert not cuda_graph, "ROCm does not support cuda_graph." 
+ if attn_scale is None: d = q.size(-1) attn_scale = 1.0 / math.sqrt(d) @@ -299,6 +321,7 @@ def fused_attn_fwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, cu_seqlens_q, cu_seqlens_kv, @@ -313,10 +336,26 @@ def fused_attn_fwd( s_quantizer, o_quantizer, attn_bias, + softmax_offset, rng_gen, rng_elts_per_thread, + return_max_logit, + cuda_graph, ) + if return_max_logit: + qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1] + # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1] + stats = output_tensors[1] + torch.log(output_tensors[2]) + amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3) + # Max -> max_logit [h] + max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype) + aux_ctx_tensors = [stats] + aux_ctx_tensors.extend(output_tensors[3:]) + return output_tensors[0], aux_ctx_tensors, max_logit + # out, aux_ctx_tensors return output_tensors[0], output_tensors[1:] @@ -346,8 +385,10 @@ def fused_attn_bwd( qkv_layout: str = "sbh3d", attn_bias_type: str = "no_bias", attn_mask_type: str = "padding", + softmax_type: str = "vanilla", window_size: Tuple[int, int] = (-1, -1), deterministic: bool = False, + cuda_graph: bool = False, ) -> Tuple[Union[torch.Tensor, None], ...]: """Fused Attention BWD for packed KV input. 
@@ -411,6 +452,8 @@ def fused_attn_bwd( type of the bias; {"no_bias", "pre_scale_bias", "post_scale_bias", "alibi"} attn_mask_type: str, default = "padding" type of the attention mask; {"padding", "causal", "padding_causal", "no_mask"} + softmax_type: str, default = "vanilla" + type of the attention softmax; {"vanilla", "off-by-one", "learnable"} window_size: Tuple[int, int], default = (-1, -1) sliding window size for local attention, where query at position i attends to keys in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q @@ -418,6 +461,8 @@ def fused_attn_bwd( window and causal mask specifically. deterministic: bool, default = False whether to execute the backward pass with deterministic behaviours. + cuda_graph: bool, default = False + whether or not cuda graph capture is enabled. Returns ---------- @@ -430,6 +475,9 @@ def fused_attn_bwd( d_bias: torch.Tensor, optional gradient tensor of Bias when attn_bias_type is "pre_scale_bias" or "post_scale_bias"; same data type and shape as Bias + d_softmax_offset: torch.Tensor, optional + gradient tensor of softmax offset in shape [1, h_q, 1, 1]. + See softmax_type in DotProductAttention for details. 
""" if attn_scale is None: d = q.size(-1) @@ -468,6 +516,7 @@ def fused_attn_bwd( QKVLayout[qkv_layout], AttnBiasType[attn_bias_type], AttnMaskType[attn_mask_type], + SoftmaxType[softmax_type], window_size, deterministic, cu_seqlens_q, @@ -485,6 +534,7 @@ def fused_attn_bwd( s_quantizer, dp_quantizer, dqkv_quantizer, + cuda_graph, ) return output_tensors diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index e4f4e619f..dd0411298 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -11,8 +11,10 @@ from ..constants import TE_DType from ..utils import get_sm_count, _empty_tensor -from ..tensor.quantized_tensor import Quantizer -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..quantized_tensor import Quantizer +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.utils import is_custom +from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer __all__ = [ @@ -77,6 +79,24 @@ def general_gemm( if not out.is_contiguous(): raise ValueError("Output tensor is not contiguous.") + # If A or B are custom tensors -> dispatch to quantizers's qgemm implementation + if is_custom(A) or is_custom(B): + return custom_gemm( + A, + B, + workspace, + out_dtype, + quantization_params, + gelu, + gelu_in, + accumulate, + layout, + out, + bias, + use_split_accumulator, + grad, + ) + debug_quantizer = None if isinstance(quantization_params, DebugQuantizer): debug_quantizer = quantization_params @@ -87,9 +107,9 @@ def general_gemm( # Use bfloat16 as default bias_dtype bias_dtype = TE_DType[torch.bfloat16 if bias is None else bias.dtype] - if isinstance(A, Float8BlockwiseQTensorBase) or isinstance(B, Float8BlockwiseQTensorBase): + if isinstance(A, Float8BlockwiseQTensorStorage) or isinstance(B, 
Float8BlockwiseQTensorStorage): # There is not use_split_accumulator == False - # implementation for Float8BlockwiseQTensorBase GEMM + # implementation for Float8BlockwiseQTensorStorage GEMM use_split_accumulator = True # Check that data format is supported diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 179c80a65..bfdee3475 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -3,685 +3,748 @@ # See LICENSE for license information. """Functionality for CPU offloading of tensors saved for backward pass.""" -from __future__ import annotations -from contextlib import nullcontext -from typing import Any, Dict, Optional +from __future__ import annotations +import contextlib +from collections import defaultdict +from dataclasses import dataclass, field +import os +import warnings +from typing import Any, Optional import torch - +from torch.autograd.graph import saved_tensors_hooks from transformer_engine.debug.pytorch.debug_state import TEDebugState -from .tensor.quantized_tensor import QuantizedTensorBase -from .tensor.float8_tensor import Float8Tensor - -__all__ = ["get_cpu_offload_context"] +import transformer_engine.pytorch as te +import transformer_engine.pytorch.cpu_offload_v1 as v1_code_path +from .quantized_tensor import ( + restore_from_saved, + prepare_for_saving, +) -CPUOffloadEnabled = False +__all__ = ["get_cpu_offload_context", "mark_not_offload", "start_offload"] -def mark_activation_offload(*tensors): - """Set the type of the offloading needed for a tensor.""" - if TEDebugState.debug_enabled: - raise RuntimeError("CPU offload is not supported in debug mode.") - - for tensor in tensors: - if tensor is None: - continue - if type(tensor) in [torch.Tensor, torch.nn.Parameter]: - tensor.activation_offloading = True - else: - data_tensors = tensor.get_data_tensors() - for tensor in data_tensors: - if tensor is not None: - tensor.activation_offloading = 
True - # This is a hack to force clear the tensor after it is offloaded. - # It is needed, because .*TensorBase classes are saved in the ctx, - # and they contain the reference to their data tensors. - tensor.needs_force_clear = True +NVTE_CPU_OFFLOAD_V1 = os.environ.get("NVTE_CPU_OFFLOAD_V1", "0") == "1" +OFFLOAD_SYNCHRONIZER = None -def is_cpu_offload_enabled() -> bool: - """Check if CPU offloading is currently enabled.""" - return CPUOffloadEnabled +def is_cpu_offload_enabled(): + """Returns True if CPU offload is enabled.""" + if NVTE_CPU_OFFLOAD_V1: + return v1_code_path.is_cpu_offload_enabled() + return OFFLOAD_SYNCHRONIZER is not None -class CpuOffloadSavedTensorHook: - """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. - In this context, the ``on_save_for_backward`` method will be called every time - a tensor is saved for backward (this includes intermediary results saved using - :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but - also those recorded by a PyTorch-defined operation). +def mark_activation_offload(*tensors): + """Set the type of the offloading needed for a tensor.""" + if NVTE_CPU_OFFLOAD_V1: + v1_code_path.mark_activation_offload(*tensors) - The ``on_get_saved_tensors`` method will be called when the backward function - of this op attempts to retrieve the saved tensor from context (this includes - :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the - as input the return value of the ``on_save_for_backward``, and is meant to return - an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of - size, device and element values. - Example: +def mark_not_offload(*tensors: torch.Tensor): + """Marks tensors to prevent them from being offloaded.""" + if NVTE_CPU_OFFLOAD_V1: + return - >>> import torch - >>> from typing import Any - >>> - >>> class DummyHook(CpuOffloadSavedTensorHook): - ... - ... 
def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - ... logging.info("On save", tensor) - ... return (tensor,) - ... - ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - ... logging.info("On get", saved_state) - ... tensor, = saved_state - ... return tensor - ... - >>> a = torch.ones(5, requires_grad=True) - >>> b = torch.ones(5, requires_grad=True) * 2 - >>> with DummyHook(): - ... y = a * b - ... - On save tensor([1., 1., 1., 1., 1.], requires_grad=True) - On save tensor([2., 2., 2., 2., 2.], grad_fn=) - >>> y.sum().backward() - On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) - On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) + tensors, tensor_obj = prepare_for_saving(*tensors) - """ - - def __init__(self) -> None: - self.inside_context = False + for tensor in tensors: + if tensor is not None: + setattr(tensor, "_TE_do_not_offload", True) - def __enter__(self): - global CPUOffloadEnabled - CPUOffloadEnabled = True + restore_from_saved(tensor_obj, tensors) - self.inside_context = True - torch._C._autograd._push_saved_tensors_default_hooks( - self.on_save_for_backward, self.on_get_saved_tensor - ) - def __exit__(self, *args: Any): - global CPUOffloadEnabled - CPUOffloadEnabled = False +def start_offload(*tensors: torch.Tensor, offload_base_tensor: bool = False): + """ + Marks point in on main stream where tensors are fully computed and ready to be offloaded. + If offload_base_tensor is True and the tensor is a view, the base tensor is offloaded + and reloaded - the stride and storage offset of the view are saved and restored after reload. + It is useful when multiple tensors are views of the same base tensor, + for example in MultiHeadAttention for interleaved q, k, v tensors. + """ + if NVTE_CPU_OFFLOAD_V1: + return - self.inside_context = False - torch._C._autograd._pop_saved_tensors_default_hooks() + def _mark_tensor_for_offload(t): + if t is None: + return + # Attach an event to mark when the tensor is ready for reload. 
+ t.start_reload_event = torch.cuda.Event() + t.start_reload_event.record(torch.cuda.current_stream()) + if offload_base_tensor and t._base is not None: + setattr(t, "offload_base_tensor", True) - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - """On save for backward.""" - raise NotImplementedError( - "`on_save_for_backward: Callable[[torch.Tensor], Any]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks" - ) + tensors, tensor_obj = prepare_for_saving(*tensors) - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - """On get saved tensor.""" - raise NotImplementedError( - "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" - "is not implemented in CpuOffloadHook class. Inherit " - "this class and implement your custom hooks" - ) + for tensor in tensors: + _mark_tensor_for_offload(tensor) + restore_from_saved(tensor_obj, tensors) -class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): - """Context-manager that offloads/recovers tensors through an offload hander. - The hook just offloads/recovers the tensor object to the handler through `tensor_push` - and `tensor_pop` interface. How the offload-handler manages the offloading, recovering - or prefetching timing is transparent to this hook. +@dataclass +class TensorGroup: + """ + TensorGroup is a collection of tensors, events and auxiliary data. + It is used multiple times in the CPU offload code. 
""" - def __init__( - self, - offload_handler: OffloadHandler, - handler_extra_kwargs: Optional[Dict[str, Any]] = None, - debug: bool = False, - ) -> None: - if handler_extra_kwargs is None: - handler_extra_kwargs = {} - self.debug: bool = debug - self.offload_handler: OffloadHandler = offload_handler - self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs - super().__init__() - - def on_save_for_backward(self, tensor: torch.Tensor) -> Any: - retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) - return retrieve_identifier - - def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: - tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) - return tensor - + tensor_list: list[torch.Tensor] = field(default_factory=list) + events: list[torch.cuda.Event] = field(default_factory=list) + aux: Any = None -class OffloadHandler: - """A base class for CPU offload-handler.""" - def __init__(self) -> None: - pass +class TensorGroupProcessor: + """ + Suppose there is a tensor group T that needs to be offloaded. + Possibly we can switch T into (T_opt, aux), where T_opt is smaller and easier to offload, + offload T_opt, reload it and then restore T from (T_opt_reloaded, aux). - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: - """Tensor push.""" - raise NotImplementedError( - "`tensor_push is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_push." - ) + This class contains static methods that perform these optimizations - for example + deduplication of tensors and restoring duplicates after reload. + """ - def tensor_pop(self, tensor_tag: Any, **kwargs): - """Tensor pop.""" - raise NotImplementedError( - "`tensor_pop is not implented in OffloadHandler class. " - "Inherit this class and implement your custom tensor_pop." 
- ) + @staticmethod + def tensor_group_process_before_offload(tensor_group: TensorGroup) -> tuple[TensorGroup, Any]: + """ + Call for a tensor group, just before offloading logic. + aux is a dictionary that contains auxiliary data, needed to restore pre-offload state. + """ + aux = {} + tensor_group = TensorGroupProcessor._switch_to_base_tensors(aux, tensor_group) + tensor_group = TensorGroupProcessor._deduplicate_tensors(aux, tensor_group) + return tensor_group, aux -class GroupCommitFunction(torch.autograd.Function): - """this is a dummy op with output identical to input. - However, it is necessary for marking a timepoint for offload handler to - accomplish all synchronizations. Implementing it as a function is necessary - because we need to actions in both forward and backward. - """ + @staticmethod + def tensor_group_process_after_reload(tensor_group: TensorGroup): + """ + Call for a tensor group, just after reload logic. + """ + assert tensor_group.aux is not None + tensor_group = TensorGroupProcessor._restore_tensor_duplicates(tensor_group) + tensor_group = TensorGroupProcessor._switch_to_views(tensor_group) + return tensor_group @staticmethod - def forward(ctx, tensor, cpu_offload_handler): - # pylint: disable=missing-function-docstring - cpu_offload_handler.on_group_commit_forward() - ctx.cpu_offload_handler = cpu_offload_handler - # return the identical tensor - return tensor + def _switch_to_base_tensors(aux, tensor_group: TensorGroup) -> TensorGroup: + """ + Changes tensors to base tensors and saves view options in aux. + + It we save multiple tensors which in fact are views of the same base tensor, + this will offload only this one base tensor. It is used for example in + MultiHeadAttention for interleaved q, k, v tensors. 
+ """ + + def _check_if_offload_base_tensor(tensor: torch.Tensor) -> bool: + if getattr(tensor, "offload_base_tensor", False): + return True + if tensor._base is not None: + # If tensor is a view of a tensor and has the same elements, + # but with different strides, we can safely offload the base tensor. + # If tensor is a view on some part of a bigger tensor, + # the decision to offload the base tensor is non-trivial and we do not do it by default. + return tensor._base.numel() == tensor.numel() + return False + + aux["views"] = [] + for tensor_id in range( # pylint: disable=consider-using-enumerate + len(tensor_group.tensor_list) + ): + tensor = tensor_group.tensor_list[tensor_id] + if _check_if_offload_base_tensor(tensor): + aux["views"].append((tensor.shape, tensor.stride(), tensor.storage_offset())) + tensor = tensor._base + assert ( + tensor is not None + ), "Cannot offload base tensor, if the tensor is not a view." + tensor_group.tensor_list[tensor_id] = tensor + else: + aux["views"].append(None) + return tensor_group @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - cpu_offload_handler = ctx.cpu_offload_handler - cpu_offload_handler.on_group_commit_backward() - return grad_output, None + def _deduplicate_tensors(aux, tensor_group: TensorGroup) -> TensorGroup: + """ + Deduplicate tensors. + """ + dedup_tensors: list[torch.Tensor] = [] + dedup_events: list[torch.cuda.Event] = [] + tensor_to_index: dict[int, int] = {} + aux["original_tensor_ids"] = [] + # If there are several duplicates of the same tensor, with different events, + # we keep only first event - every event is recorded when the tensor is ready to be offloaded, + # so it is the most optimal to use the first event. 
+ for tensor_id, tensor in enumerate(tensor_group.tensor_list): + if id(tensor) in tensor_to_index: + aux["original_tensor_ids"].append(tensor_to_index[id(tensor)]) + else: + tensor_to_index[id(tensor)] = len(dedup_tensors) + dedup_tensors.append(tensor) + dedup_events.append(tensor_group.events[tensor_id]) + aux["original_tensor_ids"].append(tensor_to_index[id(tensor)]) -group_prefetch_offload_commit = GroupCommitFunction.apply + tensor_group.tensor_list = dedup_tensors + tensor_group.events = dedup_events + return tensor_group + @staticmethod + def _restore_tensor_duplicates(tensor_group: TensorGroup) -> TensorGroup: + """ + Restore tensor duplicates. + """ + new_tensor_list = [] + new_events_list = [] + for tensor_id in range(len(tensor_group.aux["original_tensor_ids"])): + original_tensor_id = tensor_group.aux["original_tensor_ids"][tensor_id] + new_tensor_list.append(tensor_group.tensor_list[original_tensor_id]) + new_events_list.append(tensor_group.events[original_tensor_id]) + + tensor_group.tensor_list = new_tensor_list + tensor_group.events = new_events_list + return tensor_group -class SynchronizedGroupOffloadHandler(OffloadHandler): - """Offload Handler that offloads/reloads in a synchronized way. - The device-to-host and host-to-device copying happen in the same stream - as the computation kernels, thus the copying will block computation. + @staticmethod + def _switch_to_views(tensor_group: TensorGroup) -> TensorGroup: + """ + Switch to views - reverse of _switch_to_base_tensors. + """ + for tensor_id, tensor in enumerate(tensor_group.tensor_list): + if tensor_group.aux["views"][tensor_id] is not None: + tensor_group.tensor_list[tensor_id] = tensor.as_strided( + *tensor_group.aux["views"][tensor_id] + ) + return tensor_group + + +class OffloadableLayerState: + """ + Class that manages offloading and reloading of tensors for a single layer. 
""" def __init__( - self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False - ) -> None: - super().__init__() - - self.num_offload_group = num_offload_group - self.tensor_need_offloading_checker = tensor_need_offloading_checker - self.debug = debug - - self.groupid_reset() - - def groupid_reset(self): - """Groupid reset.""" - # Data structures to label saved tensors and book-keep their cpu copies. - # Currently, on push, create a new cpu tensor and copies; on pop, copies - # the tensor back to gpu and deletes the cpu tensor. - # These will increment whenever `group_commit()` is invoked - self.current_group, self.tensor_count_current_group = (0, 0) - self.torch_tensor_count = 0 - self.tensor_tag_to_state = {} - - def on_group_commit_forward(self): - """On group commit forward.""" - # finishing up with updating current group and tensor count - self.current_group += 1 # increment - self.tensor_count_current_group = 0 # reset - - def on_group_commit_backward(self): - """On group commit backward.""" - self.current_group -= 1 - assert self.current_group >= 0 - - @staticmethod - def offload(src_tensor, pin_memory=True): - """Offload.""" - - cpu_backup = torch.empty( - src_tensor.size(), - dtype=src_tensor.dtype, - layout=src_tensor.layout, - device="cpu", - pin_memory=pin_memory, + self, + offload_stream: torch.cuda.Stream, + retain_pinned_cpu_buffers: bool = False, + ): + self.offload_stream = offload_stream + self.retain_pinned_cpu_buffers = retain_pinned_cpu_buffers + + # There are 3 tensor groups: tensors on gpu before offload, + # tensors on cpu after offload, tensors on gpu after reload. + self.fwd_gpu_tensor_group = TensorGroup() + self.cpu_tensor_group = TensorGroup() + self.bwd_gpu_tensor_group = TensorGroup() + + self.aux: dict[str, Any] = {} + + # State can be one of: not_offloaded, offload_started, + # offload_finished, reload_started. 
+ self.state = "not_offloaded" + + def _validate_state(self, func_name: str, allowed_states: list[str]): + assert ( + self.state in allowed_states + ), f"Invalid state: {self.state} for {func_name}, must be one of {allowed_states}" + + def start_offload(self): + """ + Start offloading of tensors. Puts copy from GPU to CPU tasks on offload stream. + Before each copy event, the offload stream waits for the event signalling that the tensor is ready to be offloaded. + This event is recorded in the start_offload or push_tensor call. + """ + self._validate_state(func_name="start_offload", allowed_states=["not_offloaded"]) + self.state = "offload_started" + + self.fwd_gpu_tensor_group, aux = TensorGroupProcessor.tensor_group_process_before_offload( + self.fwd_gpu_tensor_group ) - cpu_backup.copy_(src_tensor, non_blocking=pin_memory) - state = (src_tensor.device, cpu_backup) - return state + allocate_cpu_buffers = ( + not self.retain_pinned_cpu_buffers or len(self.cpu_tensor_group.tensor_list) == 0 + ) - @staticmethod - def reload(state, non_blocking=None, copy_buffer=None): - """Reload.""" - dev, cpu_backup = state - if non_blocking is None: - non_blocking = cpu_backup.is_pinned() + for tensor_id, tensor in enumerate(self.fwd_gpu_tensor_group.tensor_list): + assert tensor.is_contiguous() - if copy_buffer is None: - return cpu_backup.to(dev, non_blocking=non_blocking) + # Wait for the moment the tensor is ready to be offloaded. + self.offload_stream.wait_event(self.fwd_gpu_tensor_group.events[tensor_id]) # type: ignore[arg-type] - assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" 
+ with torch.cuda.stream(self.offload_stream): + if allocate_cpu_buffers: + # empty_like is defined also for QuantizedTensors + offloaded_tensor = torch.empty_like( + tensor, device=torch.device("cpu"), pin_memory=True + ) + self.cpu_tensor_group.tensor_list.append(offloaded_tensor) + else: + assert self.cpu_tensor_group.tensor_list[tensor_id].shape == tensor.shape, ( + "CPU buffer shape does not match the offloaded tensor shape:" + f" {self.cpu_tensor_group.tensor_list[tensor_id].shape} != {tensor.shape} " + " Make sure that tensor shaped do not change between" + " iterations if retain_pinned_cpu_buffers is True." + ) + offloaded_tensor = self.cpu_tensor_group.tensor_list[tensor_id] + offloaded_tensor.copy_(tensor, non_blocking=True) + + # aux is a dictionary that contains auxiliary data like information which tensors were deduplicated, + # needed to restore pre-offload state after reload. + self.aux = aux + + self.finish_offload_event = torch.cuda.Event() + self.finish_offload_event.record(self.offload_stream) + + def release_activation_forward_gpu_memory(self): + """ + Release GPU memory of the activations. + Waits for offload to finish - memory needs to be kept alive when GPU->CPU copy is performed. + """ + self._validate_state( + func_name="release_activation_forward_gpu_memory", allowed_states=["offload_started"] + ) + self.state = "offload_finished" + + torch.cuda.current_stream().wait_event(self.finish_offload_event) # type: ignore[arg-type] + + # GPU memory can be released safely after the offload. + # Notice that the memory needs to be kept alive when GPU->CPU copy is performed. + self.fwd_gpu_tensor_group = TensorGroup() + del self.finish_offload_event + + def start_reload(self): + """ + Start reloading of tensors. + It allocates new tensors on GPU and puts copy from CPU tasks on offload stream. 
+ """ + self._validate_state(func_name="start_reload", allowed_states=["offload_finished"]) + self.state = "reload_started" + + self.bwd_gpu_tensor_group = TensorGroup() + for tensor in self.cpu_tensor_group.tensor_list: + + # Notice that reloaded tensor is allocated on main stream, + # not offloaded stream. It is because PyTorch memory allocator + # cannot move tensors from pool of one stream to another without + # calling cudaFree and cudaMalloc again. + + # empty_like is defined also for QuantizedTensors. + reloaded_tensor = torch.empty_like(tensor, device=torch.device("cuda")) + self.offload_stream.wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(self.offload_stream): + reloaded_tensor.copy_(tensor, non_blocking=True) + + reload_tensor_event = torch.cuda.Event() + reload_tensor_event.record(self.offload_stream) + self.bwd_gpu_tensor_group.events.append(reload_tensor_event) + self.bwd_gpu_tensor_group.tensor_list.append(reloaded_tensor) + + self.bwd_gpu_tensor_group.aux = self.aux + self.bwd_gpu_tensor_group = TensorGroupProcessor.tensor_group_process_after_reload( + self.bwd_gpu_tensor_group + ) - copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + """ + It is called when a tensor is saved for backward pass. + + If tensor is offloaded, returns int representing the index of the tensor in the offloaded tensor group. + If tensor is not offloaded, returns the tensor itself. + """ + self._validate_state(func_name="push_tensor", allowed_states=["not_offloaded"]) + + if self._check_if_offload(tensor): + self.fwd_gpu_tensor_group.tensor_list.append(tensor) + # The group is processed and offloaded at the end of the forward pass of current layer. + # To enable offloading of tensors faster we use self.offload_stream and record + # the events when the tensors are ready to be offloaded. + # It means that we do not need to wait to the end of current layer to start offloading. 
+ if hasattr(tensor, "start_reload_event"): + self.fwd_gpu_tensor_group.events.append(tensor.start_reload_event) + else: + self.fwd_gpu_tensor_group.events.append(torch.cuda.Event()) + self.fwd_gpu_tensor_group.events[-1].record(torch.cuda.current_stream()) + return len(self.fwd_gpu_tensor_group.tensor_list) - 1 + return tensor - return copy_buffer + def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: + """ + It is called when a tensor is used in backward pass. + Returns the tensor. If tensor was offloaded/reloaded, wait for the reload of a tensor to finish. + """ + self._validate_state( + func_name="pop_tensor", allowed_states=["not_offloaded", "reload_started"] + ) - def tensor_push(self, tensor: torch.Tensor, **kwargs): - """Tensor push.""" - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - assert tensor_tag not in self.tensor_tag_to_state - if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( - tensor + # 1. tensor not offloaded + if isinstance(tensor_or_tensor_id, torch.Tensor): + return tensor_or_tensor_id + # 2. the layer was not offloaded at all + if self.state == "not_offloaded": + return self.fwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] + + # 3. the layer was offloaded + assert self.state == "reload_started" + # wait for the tensor to be reloaded + torch.cuda.current_stream().wait_event( + self.bwd_gpu_tensor_group.events[tensor_or_tensor_id] + ) + return self.bwd_gpu_tensor_group.tensor_list[tensor_or_tensor_id] + + def release_all_memory(self): + """Release all gpu and cpu memory the state stored. 
Is called after the backward pass.""" + self.fwd_gpu_tensor_group = TensorGroup() + if not self.retain_pinned_cpu_buffers: + self.cpu_tensor_group = TensorGroup() + self.bwd_gpu_tensor_group = TensorGroup() + self.state = "not_offloaded" + + def _check_if_offload(self, t: torch.Tensor) -> bool: + """ + Check if tensor needs to be offloaded. + """ + if ( + not isinstance(t, torch.nn.Parameter) + and not getattr(t, "_TE_do_not_offload", False) + and not isinstance(t, torch._subclasses.FakeTensor) + and t.device.type == "cuda" ): - state = SynchronizedGroupOffloadHandler.offload(tensor) - self.tensor_tag_to_state[tensor_tag] = state - else: - # will be offloaded together after group commit - self.tensor_tag_to_state[tensor_tag] = tensor - - return tensor_tag + if not t.is_contiguous() and not getattr(t, "offload_base_tensor", False): + warnings.warn( + "Tried to offload non-contiguous tensor, which is not supported. Offload of" + " this tensor will be skipped." + ) + return False + + return True + return False + + def get_offloaded_total_size_mb(self) -> float: + """ + Get total size of offloaded tensors in MB, used only for testing. + """ + + def get_tensor_size_mb(tensor): + if tensor is None: + return 0 + if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): + return sum(get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + return tensor.numel() * tensor.element_size() / (1024**2) + + total_size = 0 + for tensor in self.cpu_tensor_group.tensor_list: + total_size += get_tensor_size_mb(tensor) + return total_size + + +class OffloadSynchronizer: + """ + Base class responsible for synchronizing offloading and reloading of tensors for multiple layers. + In base class we only track layer number and + create OffloadableLayerState instances for all layers, but do not start offloading or reloading. 
+ """ - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - assert tensor_tag in self.tensor_tag_to_state - state = self.tensor_tag_to_state.pop(tensor_tag) - if isinstance(state, tuple): - tensor = SynchronizedGroupOffloadHandler.reload(state) + def __init__( + self, + num_layers: int, + retain_pinned_cpu_buffers: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, + ): + self.num_layers = num_layers + self.offload_stream = offload_stream if offload_stream is not None else torch.cuda.Stream() + + self.layer_states = { + i: OffloadableLayerState(self.offload_stream, retain_pinned_cpu_buffers) + for i in range(num_layers) + } + + self.num_of_fwds = None + self.previous_bwd_layer_id = None + self.current_layer_id = None + + def fwd_step(self) -> int: + """ + Invoked before each layer forward. + """ + if self.num_of_fwds in [None, self.num_layers - 1]: + # reset the offload synchronizer + self.num_of_fwds = 0 else: - tensor = state - return tensor + self.num_of_fwds += 1 + self.current_layer_id = self.num_of_fwds + return self.current_layer_id + + def bwd_step(self, layer_num: int): + """ + Invoked before each layer backward. + """ + if self.previous_bwd_layer_id is not None: + self.layer_states[self.previous_bwd_layer_id].release_all_memory() + self.previous_bwd_layer_id = layer_num + self.current_layer_id = layer_num + + def push_tensor(self, tensor: torch.Tensor) -> int | torch.Tensor: + """Default push tensor method""" + return self.layer_states[self.num_of_fwds].push_tensor(tensor) + + def pop_tensor(self, tensor_or_tensor_id: torch.Tensor | int) -> torch.Tensor: + """Default pop tensor method""" + return self.layer_states[self.current_layer_id].pop_tensor(tensor_or_tensor_id) + + def finish_part_of_bwd(self): + """ + We need to release memory of backward - this call does that. + It needs to be invoked after every backward pass - there may be + more than one in pipeline parallelism. 
+ + It is needed, because call bwd_step is invoked before each layer backward, + but we need to release memory after the backward pass is finished. + """ + if self.previous_bwd_layer_id is not None: + self.layer_states[self.previous_bwd_layer_id].release_all_memory() + self.previous_bwd_layer_id = None + + def get_offloaded_total_size_mb(self) -> float: + """ + Get total size of offloaded tensors in MB, used only for testing. + """ + return sum( + self.layer_states[layer_id].get_offloaded_total_size_mb() + for layer_id in self.layer_states + ) -class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): - """Compared to synchronize, this uses more memory because of the buffer but - achieves better performance due to the overlapping. D2h and h2d copying are - completely hidden behind computation if computation time of a layer is longer - than host-device communication time. Bulk offloading with delay and bulk reloading - with prefetch are implemented.""" +class DefaultOffloadSynchronizer(OffloadSynchronizer): + """ + Default implementation of OffloadSynchronizer, + intended to be used in standard training workloads - with multiple forwards + and multiple backwards. 
+ """ def __init__( self, - num_offload_group, # must be <= actual number of groups (number of commits) - num_model_group, - tensor_need_offloading_checker=(lambda t: True), - double_buffering=False, - debug=False, - ) -> None: - super().__init__( - num_offload_group=num_offload_group, - tensor_need_offloading_checker=tensor_need_offloading_checker, - debug=debug, - ) - # Number of layers in the model - self.num_layers = num_model_group - # Data Structure to maintain reference to activation tensors - self.tensor_tag_to_buf = {} - # Data structure to hold the FP8/MXFP8 tensor objects - self.fp8_tensor_object_map = {} - self.float8_transpose_cache_valid = {} - self.dereferencing_list = [] - # Tracking the number of layers offloaded - self.offloaded_group_count = 0 - # Core data structure that decides the window for offloading - self.layer_window_map = {} - - # Data structures fo double buffered reloading - self.double_buffering = double_buffering - self.reload_double_buffer = [[], []] - self.double_buffer_created = False - - # Logic to make offloading load balance across computation - # for optimal CPU/GPU interconnect usage - constant = 0 - for i in range(self.num_offload_group): - self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 - if i < (self.num_layers % self.num_offload_group): - self.layer_window_map[i] += i + 1 - constant = i + 1 + num_layers: int, + num_offloaded_layers: int | None = None, + retain_pinned_cpu_buffers: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, + ): + super().__init__(num_layers, retain_pinned_cpu_buffers, offload_stream) + + # map of layers to bool meaning if layer needs to be offloaded + self.offload_layer_map: dict[int, bool] = {} + + # num_layer: int -> list of layers that need to finish offload by this moment + self.finish_offload_map: defaultdict[int, list[int]] = defaultdict(list) + # num_layer: int -> list of layers that need to start reload in this moment + 
self.start_reload_map: defaultdict[int, list[int]] = defaultdict(list) + + self._init_offload_synchronization_dicts(num_offloaded_layers) + + def _init_offload_synchronization_dicts(self, num_offloaded_layers: int): + """ + If synchronization dictionary is not provided, the number of offloaded layers is used to initialize + offload_layer_map, finish_offload_map and start_reload_map. + + The aim is to minimize memory usage by the end of the forward pass. + + The optimal strategy for that is to offload layers 0, ..., num_offloaded_layers - 1. + For layer i offload needs to finish before num_layers - num_offloaded_layers + i. + For layer i reload needs to start after num_layers - num_offloaded_layers + i. + + This ensures that - if all layers have memory footprint of T - then peak memory usage of saving activations is + (num_layers - num_offloaded_layers) * T. + """ + for layer_id in range(self.num_layers): + if layer_id < num_offloaded_layers: + self.offload_layer_map[layer_id] = True + self.finish_offload_map[self.num_layers - num_offloaded_layers + layer_id].append( + layer_id + ) + self.start_reload_map[self.num_layers - 1 - num_offloaded_layers + layer_id].append( + layer_id + ) else: - self.layer_window_map[i] += constant + self.offload_layer_map[layer_id] = False - # allocate streams and events for synchronization - self.d2h_stream = torch.cuda.Stream() - self.h2d_stream = torch.cuda.Stream() + def fwd_step(self) -> int: + """ + Invoked before each layer forward. 
+ """ + super().fwd_step() + if self.offload_layer_map.get(self.current_layer_id - 1, False): + self.layer_states[self.current_layer_id - 1].start_offload() - def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + for layer in self.finish_offload_map[self.current_layer_id]: + self.layer_states[layer].release_activation_forward_gpu_memory() + return self.current_layer_id - torch_stray_tensor = isinstance( - tensor, - ( - torch._subclasses.fake_tensor.FakeTensor, - torch._subclasses.functional_tensor.FunctionalTensor, - ), - ) + def bwd_step(self, layer_num: int): + """ + Invoked before each layer backward. + """ + super().bwd_step(layer_num) - is_quantized_tensor = isinstance(tensor, QuantizedTensorBase) - - if not torch_stray_tensor: - - # obtain a unique tensor tag - tensor_tag = (self.current_group, self.tensor_count_current_group) - self.tensor_count_current_group += 1 - - assert tensor_tag not in self.tensor_tag_to_state - - if is_quantized_tensor: - tensor_list, _ = tensor.prepare_for_saving() - - self.tensor_tag_to_state[tensor_tag] = [] - self.tensor_tag_to_buf[tensor_tag] = [] - - # Added support for de-duplicating FP8 param tensors - for _, value in self.fp8_tensor_object_map.items(): - if tensor is value: - self.dereferencing_list.append(tensor_tag) - break - - self.fp8_tensor_object_map[tensor_tag] = tensor - if isinstance(tensor, Float8Tensor): - self.float8_transpose_cache_valid[tensor_tag] = getattr( - tensor, "_transpose_invalid" - ) - else: - tensor_list = [tensor] - - for t in tensor_list: - if is_quantized_tensor: - self.tensor_tag_to_state[tensor_tag].append(t) - else: - self.tensor_tag_to_state[tensor_tag] = t - - if ( - self.current_group < self.num_offload_group - and self.tensor_need_offloading_checker(t) - ): - if is_quantized_tensor: - self.tensor_tag_to_buf[tensor_tag].append(t) - # Need to clear the internal data reference for the quantized tensors - tensor.clear() - else: - self.tensor_tag_to_buf[tensor_tag] = t - else: - 
tensor_tag = (-1, self.torch_tensor_count) - self.torch_tensor_count += 1 - self.tensor_tag_to_state[tensor_tag] = tensor - - return tensor_tag - - def tensor_pop(self, tensor_tag, **kwargs): - """Tensor pop.""" - assert tensor_tag in self.tensor_tag_to_state - tensor = self.tensor_tag_to_state.pop(tensor_tag) - - # Handling the quantized tensor case specially here - if isinstance(tensor, list): - # If it's a duplicated tensor, we don't need to locally - # write back a tensor as it would already be written - if tensor_tag in self.dereferencing_list: - self.dereferencing_list.remove(tensor_tag) - else: - self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) - tensor = self.fp8_tensor_object_map.pop(tensor_tag) + for layer in self.start_reload_map[layer_num]: + self.layer_states[layer].start_reload() - if self.double_buffering: - tensor._do_not_clear = True - self.tensor_tag_to_buf.pop(tensor_tag, None) - # the tensor should have been copied back in on_group_commit_backward() - # which invokes bulk_reload_group. - assert not isinstance(tensor, tuple) - return tensor +class ManualOffloadSynchronizer(OffloadSynchronizer): + """ + Manual implementation of OffloadSynchronizer, + all synchronization is done manually by the user by using + one of the following methods: + - start_offload_layer + - release_activation_forward_gpu_memory + - start_reload_layer + + This implementation is intended to be used in more complex trainigs workflows. + It is useful for example in pipeline parallelism. 
+ """ - def bulk_offload_group(self, group_to_offload): - """Bulk offload group.""" - with torch.cuda.stream(self.d2h_stream): - for tensor_tag, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_tag - if group_id == group_to_offload: - assert not isinstance(state, tuple) - - is_quantized_tensor = isinstance(state, list) - - if is_quantized_tensor: - tensor_list = state - self.tensor_tag_to_state[tensor_tag] = [] - else: - tensor_list = [state] - - for tensor_on_device in tensor_list: - # `tensor_offloaded` is a hacky way of dealing with columnwise-only - # quantized tensors for CPU offloading. The complication is due to - # the `rowwise_data` being `None`. The offloading checker incorrectly - # returns `False` and the entire `state` ([None, columnwise_tensor]) - # is added to the tensor tag state dict. A better design would change - # how quantized tensors are kept track of in the offload handler. - # Currently at every stage it is ensured that a quantized tensor is a - # list whereas a non-quantized tensor is standalone object, which is - # not good! 
TODO(@sanandaraj5597) - tensor_offloaded = False - # if offload, return the reference to cpu copy - if self.tensor_need_offloading_checker(tensor_on_device): - tensor_offloaded = True - state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) - if is_quantized_tensor: - if tensor_offloaded: - self.tensor_tag_to_state[tensor_tag].append(state) - else: - self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) - else: - self.tensor_tag_to_state[tensor_tag] = state - - def synchronize_on_group_commit_forward(self, current_group): - """Synchronize on group commit forward.""" - - # For the first group, kickstart the offload after we have - # the first compute completion - if current_group == 0: - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - - if not self.double_buffer_created: - # Creating the first copy of double buffer for tensors that are offloaded - for tensor_tag, buf in self.tensor_tag_to_buf.items(): - if isinstance(buf, list): - for b in buf: - self.reload_double_buffer[0].append( - torch.empty_like(b) if self.double_buffering else None - ) - else: - self.reload_double_buffer[0].append( - torch.empty_like(buf) if self.double_buffering else None - ) - - self.bulk_offload_group(current_group) - - # Window map data structure helps us synchronize based on number - # of layers offloaded - if self.layer_window_map[self.offloaded_group_count] == current_group: - - # Stream synchronization both ways - self.d2h_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.d2h_stream) - - # Time to free the activation memory after usage - for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): - if tensor_tag[0] == self.offloaded_group_count: - if hasattr(tensor_buf, "needs_force_clear"): - # Need to clear activation tensor - sometimes references persist in the code. 
- # This is the case for example with the Float8TensorBase class, - # which is saved directly inside the ctx while its internal tensors are - # saved inside save_for_backward. - tensor_buf.data = torch.Tensor() - # Release the pointer to the tensor - self.tensor_tag_to_buf[tensor_tag] = None - - # Time to offload the next group - if self.offloaded_group_count < (self.num_offload_group - 1): - self.bulk_offload_group(self.offloaded_group_count + 1) - - # Increment the offload group count to keep track - self.offloaded_group_count += 1 - - if not self.double_buffer_created: - # Creating second copy of double buffer for tensors that are offloaded - if current_group == (self.num_layers - 1): - for buf in self.reload_double_buffer[0]: - self.reload_double_buffer[1].append( - torch.empty_like(buf) if self.double_buffering else None - ) - self.double_buffer_created = True - - def on_group_commit_forward(self): - """This function will cause host device synchronization""" - # handle synchronization events - self.synchronize_on_group_commit_forward(self.current_group) - - super().on_group_commit_forward() - - def bulk_reload_group(self, group_to_reload): - """Bulk reload group.""" - assert group_to_reload < self.num_offload_group - - buffer_idx = 0 - double_buffer_idx = group_to_reload % 2 - - main_stream = torch.cuda.current_stream() - - with torch.cuda.stream(self.h2d_stream): - # move back tensors - for tensor_label, state in self.tensor_tag_to_state.items(): - group_id, _ = tensor_label - if group_id == group_to_reload: - - if isinstance(state, tuple): - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] - else: - with torch.cuda.stream(main_stream): - reload_buffer = torch.empty_like( - state[1], device=torch.cuda.current_device() - ) - - recovered_tensor = SynchronizedGroupOffloadHandler.reload( - state, True, reload_buffer - ) - buffer_idx = buffer_idx + 1 - self.tensor_tag_to_state[tensor_label] = recovered_tensor - 
elif isinstance(state, list): - tensor_list = [] - for state_tuple in state: - - if isinstance(state_tuple, tuple): - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][ - buffer_idx - ] - else: - with torch.cuda.stream(main_stream): - reload_buffer = torch.empty_like( - state_tuple[1], device=torch.cuda.current_device() - ) - - tensor_list.append( - SynchronizedGroupOffloadHandler.reload( - state_tuple, - True, - reload_buffer, - ) - ) - buffer_idx = buffer_idx + 1 - else: - tensor_list.append(state_tuple) - - # No need to write back the duplicated tensor againn - # to the same location, this check ensures that - if tensor_label in self.dereferencing_list: - self.dereferencing_list.remove(tensor_label) - else: - _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved( - tensor_list - ) - - if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): - self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( - self.float8_transpose_cache_valid.pop(tensor_label) - ) - - self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( - tensor_label - ) - - def on_group_commit_backward(self): - # first decrement the current group. - # after last commit in forward, the group will +1; in backward it -1. - # Finally it should be decremented to 0. 
- self.current_group -= 1 - assert self.current_group >= 0 - - # Layer window data structure helps us to reload at right times - if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: - - # Stream synchronization both ways - self.h2d_stream.wait_stream(torch.cuda.current_stream()) - torch.cuda.current_stream().wait_stream(self.h2d_stream) - - # Time to reload the next group - self.bulk_reload_group(self.offloaded_group_count - 1) - - # Decrease the offloading group counter - self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 - - # Last group computation needs to wait till all the reloads complete - if self.current_group == 0: - torch.cuda.current_stream().wait_stream(self.h2d_stream) - self.offloaded_group_count = 0 + def start_offload_layer(self, layer_id: int): + """ + Start offloading of the layer. + Each tensor GPU->CPU copy is done asynchronously on the offload stream. + Start of each copy is started after tensor_push() is called on the current stream. + """ + self.layer_states[layer_id].start_offload() + + def release_activation_forward_gpu_memory(self, layer_id: int): + """ + Release memory of the activations of the layer. + It waits for the offload of the layer to finish. + """ + self.layer_states[layer_id].release_activation_forward_gpu_memory() + + def start_reload_layer(self, layer_id: int): + """ + Start reloading of the layer. + Each tensor reload is awaited to finish before tensor_pop() for that tensor is called on the current stream. 
+ """ + self.layer_states[layer_id].start_reload() def get_cpu_offload_context( enabled: bool = False, - num_layers: int = 1, + num_layers: Optional[int] = 1, model_layers: int = 1, offload_activations: bool = True, offload_weights: bool = False, - double_buffering: bool = False, + double_buffering: bool = False, # pylint: disable=unused-argument + manual_synchronization: bool = False, + retain_pinned_cpu_buffers: bool = False, + offload_stream: Optional[torch.cuda.Stream] = None, ): """ - This function returns the CPU Offload context and the synchronizer function that needs to be - used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + CPU Offloading feature for seqeuences of layers. Can be used for arbitrary layers, not necessarily + for these provided by the TE. Usage: .. code-block:: python - cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + cpu_offload_context, sync_function = get_cpu_offload_context(...) - with cpu_offload_context: - te_layer.forward(inp_tensor) - cpu_offload_synchronizer() + for _ in range(num_layers): + with cpu_offload_context: + x = layers[i].forward(x) + x = sync_function(x) Parameters ---------- enabled: bool, default = `False` When set to True, CPU Offloading functionality is enabled. num_layers: int, default = 1 - Determines the number of transformer layers - you want to offload activations/weights for. + Determines the number of layers + you want to offload activations/weights for. model_layers: int, default = 1 - Number of layers in the model that will be used under this context. + Number of layers in the model that will be used under this context. offload_activations: bool, default = `True` - When set to `True`, offloads the activations for the TE layer. + Deprecated. offload_weights: bool, default = `True` - When set to `True`, offloads the weights for the TE layer. + Deprecated. 
double_buffering: bool, default = `False` - When set to `True`, uses double buffering for offloading. + Deprecated. + retain_pinned_cpu_buffers: bool, default = `False` + If True, the pinned CPU buffers are retained after offloading + and reused for the next iteration. It is useful for cuda graphs capture. + manual_synchronization: bool, default = `False` + If True, the synchronization is done manually by the user. + Additional argument manual_controller is returned. See more in manual control section. + offload_stream: torch.cuda.Stream, default = `None` + If provided, the offload stream is used for offloading and reloading. + Otherwise, a new stream is allocated internally. It can be other than None + only if manual_synchronization is True. + + Manual synchronization + ---------- + By default, layers are offloaded/reloaded asynchronously + with respect to the current forward/backward stream with predefined synchronization, + to ensure that activation memory usage is equal to + `(num_layers - num_offloaded_layers) * T`, where `T` is the memory footprint of a layer. + + For more control over the offloading and reloading process, you can set `manual_synchronization=True`. + In this case, an additional argument, `manual_controller`, is returned. + + The `manual_controller` provides the following methods: + - `start_offload_layer(layer_id: int)` + - `release_activation_forward_gpu_memory(layer_id: int)` + - `start_reload_layer(layer_id: int)` + + If none of these methods are invoked for a given layer, that layer will not be offloaded or reloaded. + If `start_offload_layer()` is called for a layer, offload copies for that layer begin asynchronously on the offload stream. + + Since GPU activations must be kept in memory until the copy is finished, pointers to all activations are stored. + To release this memory, you need to call `release_activation_forward_gpu_memory(layer_id)`. 
+ This method makes the current stream wait for an event recorded on the offload stream after all tensors from the layer have been offloaded. + + The `start_reload_layer()` method is used to start reloading a layer. + Each tensor reload is awaited to finish before `tensor_pop()` for that tensor is called on the current stream. + + You can provide an `offload_stream` to be used for offload and reload operations. + This allows for more detailed synchronization, such as delaying the start of offloading. + + Example: + .. code-block:: python + offload_stream = torch.cuda.Stream() + cpu_offload_context, sync_function, manual_controller = get_cpu_offload_context( + enabled=True, model_layers=num_layers, manual_synchronization=True, offload_stream=offload_stream) + + for i in range(num_layers): + with cpu_offload_context: + out[i] = layers[i].forward(inp[i]) + out[i] = sync_function(out[i]) + manual_controller.start_offload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + manual_controller.release_activation_forward_gpu_memory(i) + + for i in range(num_layers - 1, -1, -1): + manual_controller.start_reload_layer(i) + + offload_stream.synchronize() + for i in range(num_layers): + out[i].sum().backward() + + V1 code path + ---------- + If you want to use the v1 code path for offloading, + please set the environment variable NVTE_CPU_OFFLOAD_V1 to 1. """ + if NVTE_CPU_OFFLOAD_V1: + return v1_code_path.get_cpu_offload_context( + enabled=enabled, + num_layers=num_layers, + model_layers=model_layers, + offload_activations=offload_activations, + offload_weights=offload_weights, + double_buffering=double_buffering, + ) if not offload_weights and not offload_activations: raise ValueError( @@ -690,8 +753,6 @@ def get_cpu_offload_context( ) if offload_weights: - import warnings - warnings.warn( "Offloading weights is deprecated. 
Using offload_weights=True does not have any" " effect.", @@ -700,26 +761,100 @@ def get_cpu_offload_context( # Weights offloading is deprecated but we maintain backward compatibility by doing nothing. if not offload_activations: - return nullcontext(), lambda x: x - - def tensor_need_offloading_checker_activations(tensor): - return hasattr(tensor, "activation_offloading") + return contextlib.nullcontext(), lambda x: x - tensor_need_offloading_checker = tensor_need_offloading_checker_activations + if TEDebugState.debug_enabled: + raise RuntimeError("CPU offload is not supported in debug mode.") - cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( - num_offload_group=num_layers, - num_model_group=model_layers, - tensor_need_offloading_checker=tensor_need_offloading_checker, - double_buffering=double_buffering, - ) + if not manual_synchronization: + assert ( + num_layers <= model_layers - 1 + ), "Cannot offload all layers without manual synchronization - last layer is not offloaded." + if num_layers == model_layers - 1: + warnings.warn( + "Offloading num_layers == model_layers - 1 is not recommended, it prevents" + " overlapping of computation and offload/reload." 
+ ) + + assert ( + offload_stream is None or manual_synchronization + ), "offload_stream can be provided only if manual_synchronization is True" + + if manual_synchronization: + offload_synchronizer = ManualOffloadSynchronizer( + model_layers, retain_pinned_cpu_buffers, offload_stream + ) + else: + offload_synchronizer = DefaultOffloadSynchronizer( + model_layers, + num_layers, + retain_pinned_cpu_buffers, + offload_stream, + ) - def group_prefetch_offload_commit_async(tensor): - return group_prefetch_offload_commit(tensor, cpu_offload_handler) + class _CpuOffloadContext(contextlib.ContextDecorator): + def __init__(self): + self.current_layer = None + self.previous_offload_synchronizer = None + self.offload_synchronizer = offload_synchronizer + + self.inside_context = False + + def __enter__(self): + assert ( + self.inside_context is False + ), "Offloading context was entered without synchronization function being called." + self.inside_context = True + self._hooks_ctx = saved_tensors_hooks( + offload_synchronizer.push_tensor, offload_synchronizer.pop_tensor + ) + self._hooks_ctx.__enter__() + global OFFLOAD_SYNCHRONIZER + self.previous_offload_synchronizer = OFFLOAD_SYNCHRONIZER + OFFLOAD_SYNCHRONIZER = offload_synchronizer + self.current_layer = offload_synchronizer.fwd_step() + return self + + def __exit__(self, *args): + self._hooks_ctx.__exit__(*args) + global OFFLOAD_SYNCHRONIZER + OFFLOAD_SYNCHRONIZER = self.previous_offload_synchronizer + self.inside_context = False + + def synchronization_function(self, tensor): + """ + This function is used to catch the backward pass of the model. + """ + assert tensor.requires_grad is True + assert self.current_layer is not None + cur_layer = self.current_layer + assert ( + self.inside_context is False + ), "Synchronization function was called without offloading context being entered." 
+ + def hook(_): + # offload_synchronizer.finish_part_of_bwd needs + # to be called after every backward pass - there may be + # more than one in pipeline parallelism. + torch.autograd.variable.Variable._execution_engine.queue_callback( + offload_synchronizer.finish_part_of_bwd + ) + offload_synchronizer.bwd_step(cur_layer) + + tensor.grad_fn.register_prehook(hook) + return tensor + + cpu_offload_context = _CpuOffloadContext() if enabled: + if manual_synchronization: + return ( + cpu_offload_context, + cpu_offload_context.synchronization_function, + offload_synchronizer, + ) return ( - CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), - group_prefetch_offload_commit_async, + cpu_offload_context, + cpu_offload_context.synchronization_function, ) - return nullcontext(), group_prefetch_offload_commit_async + return contextlib.nullcontext(), lambda x: x diff --git a/transformer_engine/pytorch/cpu_offload_v1.py b/transformer_engine/pytorch/cpu_offload_v1.py new file mode 100644 index 000000000..9f904864a --- /dev/null +++ b/transformer_engine/pytorch/cpu_offload_v1.py @@ -0,0 +1,743 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Functionality for CPU offloading of tensors saved for backward pass.""" +from __future__ import annotations +from contextlib import nullcontext +from typing import Any, Dict, Optional + +import torch + +from transformer_engine.debug.pytorch.debug_state import TEDebugState +from .quantized_tensor import QuantizedTensorStorage +from .tensor.float8_tensor import Float8Tensor + +__all__ = ["get_cpu_offload_context"] + +CPUOffloadEnabled = False +CPUOffloadedLayer = False + + +def mark_activation_offload(*tensors): + """Set the type of the offloading needed for a tensor.""" + if TEDebugState.debug_enabled: + raise RuntimeError("CPU offload is not supported in debug mode.") + + for tensor in tensors: + if tensor is None: + continue + if type(tensor) in [torch.Tensor, torch.nn.Parameter]: + tensor.activation_offloading = True + else: + data_tensors = tensor.get_data_tensors() + for tensor in data_tensors: + if tensor is not None: + tensor.activation_offloading = True + # This is a hack to force clear the tensor after it is offloaded. + # It is needed, because .*TensorStorage classes are saved in the ctx, + # and they contain the reference to their data tensors. + tensor.needs_force_clear = True + + +def is_cpu_offload_enabled() -> bool: + """Check if CPU offloading is currently enabled.""" + return CPUOffloadEnabled + + +def is_current_layer_offloaded() -> bool: + """Check if current layers is being offloaded.""" + return CPUOffloadedLayer + + +class CpuOffloadSavedTensorHook: + """Contex-manager that executes a pair of pack/unpack hooks for saved tensors. + + In this context, the ``on_save_for_backward`` method will be called every time + a tensor is saved for backward (this includes intermediary results saved using + :func:`~torch.autograd.function._ContextMethodMixin.save_for_backward` but + also those recorded by a PyTorch-defined operation). 
+ + The ``on_get_saved_tensors`` method will be called when the backward function + of this op attempts to retrieve the saved tensor from context (this includes + :func: `torch.Tensor.backward()` or :func: `torch.autograd.grad()`. It takes the + as input the return value of the ``on_save_for_backward``, and is meant to return + an identical copy of the tensor being saved by ``on_save_for_backward`` in terms of + size, device and element values. + + Example: + + >>> import torch + >>> from typing import Any + >>> + >>> class DummyHook(CpuOffloadSavedTensorHook): + ... + ... def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + ... logging.info("On save", tensor) + ... return (tensor,) + ... + ... def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + ... logging.info("On get", saved_state) + ... tensor, = saved_state + ... return tensor + ... + >>> a = torch.ones(5, requires_grad=True) + >>> b = torch.ones(5, requires_grad=True) * 2 + >>> with DummyHook(): + ... y = a * b + ... + On save tensor([1., 1., 1., 1., 1.], requires_grad=True) + On save tensor([2., 2., 2., 2., 2.], grad_fn=) + >>> y.sum().backward() + On get (tensor([1., 1., 1., 1., 1.], requires_grad=True),) + On get (tensor([2., 2., 2., 2., 2.], grad_fn=),) + + """ + + def __init__(self) -> None: + self.inside_context = False + + def __enter__(self): + global CPUOffloadEnabled + CPUOffloadEnabled = True + + self.inside_context = True + torch._C._autograd._push_saved_tensors_default_hooks( + self.on_save_for_backward, self.on_get_saved_tensor + ) + + def __exit__(self, *args: Any): + global CPUOffloadEnabled + CPUOffloadEnabled = False + + self.inside_context = False + torch._C._autograd._pop_saved_tensors_default_hooks() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + """On save for backward.""" + raise NotImplementedError( + "`on_save_for_backward: Callable[[torch.Tensor], Any]`" + "is not implemented in CpuOffloadHook class. 
Inherit " + "this class and implement your custom hooks" + ) + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + """On get saved tensor.""" + raise NotImplementedError( + "`on_get_saved_tensors: Callable[[Any], torch.Tensor]`" + "is not implemented in CpuOffloadHook class. Inherit " + "this class and implement your custom hooks" + ) + + +class CpuOffloadHookWithOffloadHandler(CpuOffloadSavedTensorHook): + """Context-manager that offloads/recovers tensors through an offload hander. + + The hook just offloads/recovers the tensor object to the handler through `tensor_push` + and `tensor_pop` interface. How the offload-handler manages the offloading, recovering + or prefetching timing is transparent to this hook. + """ + + def __init__( + self, + offload_handler: OffloadHandler, + handler_extra_kwargs: Optional[Dict[str, Any]] = None, + debug: bool = False, + ) -> None: + if handler_extra_kwargs is None: + handler_extra_kwargs = {} + self.debug: bool = debug + self.offload_handler: OffloadHandler = offload_handler + self.handler_extra_kwargs: Dict[str, Any] = handler_extra_kwargs + super().__init__() + + def on_save_for_backward(self, tensor: torch.Tensor) -> Any: + retrieve_identifier = self.offload_handler.tensor_push(tensor, **self.handler_extra_kwargs) + return retrieve_identifier + + def on_get_saved_tensor(self, saved_state: Any) -> torch.Tensor: + tensor = self.offload_handler.tensor_pop(saved_state, **self.handler_extra_kwargs) + return tensor + + +class OffloadHandler: + """A base class for CPU offload-handler.""" + + def __init__(self) -> None: + pass + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + """Tensor push.""" + raise NotImplementedError( + "`tensor_push is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_push." 
+ ) + + def tensor_pop(self, tensor_tag: Any, **kwargs): + """Tensor pop.""" + raise NotImplementedError( + "`tensor_pop is not implented in OffloadHandler class. " + "Inherit this class and implement your custom tensor_pop." + ) + + +class GroupCommitFunction(torch.autograd.Function): + """this is a dummy op with output identical to input. + However, it is necessary for marking a timepoint for offload handler to + accomplish all synchronizations. Implementing it as a function is necessary + because we need to actions in both forward and backward. + """ + + @staticmethod + def forward(ctx, tensor, cpu_offload_handler): + # pylint: disable=missing-function-docstring + cpu_offload_handler.on_group_commit_forward() + ctx.cpu_offload_handler = cpu_offload_handler + # return the identical tensor + return tensor + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + cpu_offload_handler = ctx.cpu_offload_handler + cpu_offload_handler.on_group_commit_backward() + return grad_output, None + + +group_prefetch_offload_commit = GroupCommitFunction.apply + + +class SynchronizedGroupOffloadHandler(OffloadHandler): + """Offload Handler that offloads/reloads in a synchronized way. + The device-to-host and host-to-device copying happen in the same stream + as the computation kernels, thus the copying will block computation. + """ + + def __init__( + self, num_offload_group, tensor_need_offloading_checker=(lambda _: True), debug=False + ) -> None: + super().__init__() + + self.num_offload_group = num_offload_group + self.tensor_need_offloading_checker = tensor_need_offloading_checker + self.debug = debug + + self.groupid_reset() + + def groupid_reset(self): + """Groupid reset.""" + # Data structures to label saved tensors and book-keep their cpu copies. + # Currently, on push, create a new cpu tensor and copies; on pop, copies + # the tensor back to gpu and deletes the cpu tensor. 
+ # These will increment whenever `group_commit()` is invoked + self.current_group, self.tensor_count_current_group = (0, 0) + self.torch_tensor_count = 0 + self.tensor_tag_to_state = {} + + def on_group_commit_forward(self): + """On group commit forward.""" + # finishing up with updating current group and tensor count + self.current_group += 1 # increment + self.tensor_count_current_group = 0 # reset + + def on_group_commit_backward(self): + """On group commit backward.""" + self.current_group -= 1 + assert self.current_group >= 0 + + @staticmethod + def offload(src_tensor, pin_memory=True): + """Offload.""" + + cpu_backup = torch.empty( + src_tensor.size(), + dtype=src_tensor.dtype, + layout=src_tensor.layout, + device="cpu", + pin_memory=pin_memory, + ) + + cpu_backup.copy_(src_tensor, non_blocking=pin_memory) + state = (src_tensor.device, cpu_backup) + return state + + @staticmethod + def reload(state, non_blocking=None, copy_buffer=None): + """Reload.""" + dev, cpu_backup = state + if non_blocking is None: + non_blocking = cpu_backup.is_pinned() + + if copy_buffer is None: + return cpu_backup.to(dev, non_blocking=non_blocking) + + assert cpu_backup.size() == copy_buffer.size(), "Can't copy two buffers of different sizes!" 
+ + copy_buffer.copy_(cpu_backup, non_blocking=non_blocking) + + return copy_buffer + + def tensor_push(self, tensor: torch.Tensor, **kwargs): + """Tensor push.""" + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + assert tensor_tag not in self.tensor_tag_to_state + if self.current_group < self.num_offload_group and self.tensor_need_offloading_checker( + tensor + ): + state = SynchronizedGroupOffloadHandler.offload(tensor) + self.tensor_tag_to_state[tensor_tag] = state + else: + # will be offloaded together after group commit + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + assert tensor_tag in self.tensor_tag_to_state + state = self.tensor_tag_to_state.pop(tensor_tag) + if isinstance(state, tuple): + tensor = SynchronizedGroupOffloadHandler.reload(state) + else: + tensor = state + return tensor + + +class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): + """Compared to synchronize, this uses more memory because of the buffer but + achieves better performance due to the overlapping. D2h and h2d copying are + completely hidden behind computation if computation time of a layer is longer + than host-device communication time. 
Bulk offloading with delay and bulk reloading + with prefetch are implemented.""" + + def __init__( + self, + num_offload_group, # must be <= actual number of groups (number of commits) + num_model_group, + tensor_need_offloading_checker=(lambda t: True), + double_buffering=False, + debug=False, + ) -> None: + super().__init__( + num_offload_group=num_offload_group, + tensor_need_offloading_checker=tensor_need_offloading_checker, + debug=debug, + ) + # Number of layers in the model + self.num_layers = num_model_group + # Data Structure to maintain reference to activation tensors + self.tensor_tag_to_buf = {} + # Data structure to hold the FP8/MXFP8 tensor objects + self.fp8_tensor_object_map = {} + self.float8_transpose_cache_valid = {} + self.dereferencing_list = [] + # Tracking the number of layers offloaded + self.offloaded_group_count = 0 + # Core data structure that decides the window for offloading + self.layer_window_map = {} + + # Data structures fo double buffered reloading + self.double_buffering = double_buffering + self.reload_double_buffer = [[], []] + self.double_buffer_created = False + + # Logic to make offloading load balance across computation + # for optimal CPU/GPU interconnect usage + constant = 0 + for i in range(self.num_offload_group): + self.layer_window_map[i] = ((self.num_layers // self.num_offload_group) * (i + 1)) - 1 + if i < (self.num_layers % self.num_offload_group): + self.layer_window_map[i] += i + 1 + constant = i + 1 + else: + self.layer_window_map[i] += constant + + # allocate streams and events for synchronization + self.d2h_stream = torch.cuda.Stream() + self.h2d_stream = torch.cuda.Stream() + + def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: + global CPUOffloadedLayer + + torch_stray_tensor = isinstance( + tensor, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), + ) + + is_quantized_tensor = isinstance(tensor, QuantizedTensorStorage) + + if not 
torch_stray_tensor: + + # obtain a unique tensor tag + tensor_tag = (self.current_group, self.tensor_count_current_group) + self.tensor_count_current_group += 1 + + assert tensor_tag not in self.tensor_tag_to_state + + if is_quantized_tensor: + tensor_list, _ = tensor.prepare_for_saving() + + self.tensor_tag_to_state[tensor_tag] = [] + self.tensor_tag_to_buf[tensor_tag] = [] + + # Added support for de-duplicating FP8 param tensors + for _, value in self.fp8_tensor_object_map.items(): + if tensor is value: + self.dereferencing_list.append(tensor_tag) + break + + self.fp8_tensor_object_map[tensor_tag] = tensor + if isinstance(tensor, Float8Tensor): + self.float8_transpose_cache_valid[tensor_tag] = getattr( + tensor, "_transpose_invalid" + ) + else: + tensor_list = [tensor] + + for t in tensor_list: + if is_quantized_tensor: + self.tensor_tag_to_state[tensor_tag].append(t) + else: + self.tensor_tag_to_state[tensor_tag] = t + + if ( + self.current_group < self.num_offload_group + and self.tensor_need_offloading_checker(t) + ): + if is_quantized_tensor: + self.tensor_tag_to_buf[tensor_tag].append(t) + # Need to clear the internal data reference for the quantized tensors + tensor.clear() + else: + self.tensor_tag_to_buf[tensor_tag] = t + + # Needed to differentiate non offloaded layer's attention + # QKV layout of attention of non-offloaded layer needs + # to be modified while reloading + CPUOffloadedLayer = True + else: + tensor_tag = (-1, self.torch_tensor_count) + self.torch_tensor_count += 1 + self.tensor_tag_to_state[tensor_tag] = tensor + + return tensor_tag + + def tensor_pop(self, tensor_tag, **kwargs): + """Tensor pop.""" + global CPUOffloadedLayer + + assert tensor_tag in self.tensor_tag_to_state + tensor = self.tensor_tag_to_state.pop(tensor_tag) + + # Handling the quantized tensor case specially here + if isinstance(tensor, list): + # If it's a duplicated tensor, we don't need to locally + # write back a tensor as it would already be written + if tensor_tag 
in self.dereferencing_list: + self.dereferencing_list.remove(tensor_tag) + else: + self.fp8_tensor_object_map[tensor_tag].restore_from_saved(tensor) + tensor = self.fp8_tensor_object_map.pop(tensor_tag) + + if self.double_buffering: + tensor._do_not_clear = True + + self.tensor_tag_to_buf.pop(tensor_tag, None) + # the tensor should have been copied back in on_group_commit_backward() + # which invokes bulk_reload_group. + assert not isinstance(tensor, tuple) + return tensor + + def bulk_offload_group(self, group_to_offload): + """Bulk offload group.""" + with torch.cuda.stream(self.d2h_stream): + for tensor_tag, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_tag + if group_id == group_to_offload: + assert not isinstance(state, tuple) + + is_quantized_tensor = isinstance(state, list) + + if is_quantized_tensor: + tensor_list = state + self.tensor_tag_to_state[tensor_tag] = [] + else: + tensor_list = [state] + + for tensor_on_device in tensor_list: + # `tensor_offloaded` is a hacky way of dealing with columnwise-only + # quantized tensors for CPU offloading. The complication is due to + # the `rowwise_data` being `None`. The offloading checker incorrectly + # returns `False` and the entire `state` ([None, columnwise_tensor]) + # is added to the tensor tag state dict. A better design would change + # how quantized tensors are kept track of in the offload handler. + # Currently at every stage it is ensured that a quantized tensor is a + # list whereas a non-quantized tensor is standalone object, which is + # not good! 
TODO(@sanandaraj5597) + tensor_offloaded = False + # if offload, return the reference to cpu copy + if self.tensor_need_offloading_checker(tensor_on_device): + tensor_offloaded = True + state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) + if is_quantized_tensor: + if tensor_offloaded: + self.tensor_tag_to_state[tensor_tag].append(state) + else: + self.tensor_tag_to_state[tensor_tag].append(tensor_on_device) + else: + self.tensor_tag_to_state[tensor_tag] = state + + def synchronize_on_group_commit_forward(self, current_group): + """Synchronize on group commit forward.""" + global CPUOffloadedLayer + + # For the first group, kickstart the offload after we have + # the first compute completion + if current_group == 0: + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + + if not self.double_buffer_created: + # Creating the first copy of double buffer for tensors that are offloaded + for tensor_tag, buf in self.tensor_tag_to_buf.items(): + if isinstance(buf, list): + for b in buf: + self.reload_double_buffer[0].append( + torch.empty_like(b) if self.double_buffering else None + ) + else: + self.reload_double_buffer[0].append( + torch.empty_like(buf) if self.double_buffering else None + ) + + self.bulk_offload_group(current_group) + + # Window map data structure helps us synchronize based on number + # of layers offloaded + if self.layer_window_map[self.offloaded_group_count] == current_group: + + # Stream synchronization both ways + self.d2h_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.d2h_stream) + + # Time to free the activation memory after usage + for tensor_tag, tensor_buf in self.tensor_tag_to_buf.items(): + if tensor_tag[0] == self.offloaded_group_count: + if hasattr(tensor_buf, "needs_force_clear"): + # Need to clear activation tensor - sometimes references persist in the code. 
+ # This is the case for example with the Float8TensorStorage class, + # which is saved directly inside the ctx while its internal tensors are + # saved inside save_for_backward. + tensor_buf.data = torch.Tensor() + # Release the pointer to the tensor + self.tensor_tag_to_buf[tensor_tag] = None + + # Time to offload the next group + if self.offloaded_group_count < (self.num_offload_group - 1): + self.bulk_offload_group(self.offloaded_group_count + 1) + + # Increment the offload group count to keep track + self.offloaded_group_count += 1 + + if current_group == (self.num_offload_group - 1): + CPUOffloadedLayer = False + + if not self.double_buffer_created: + # Creating second copy of double buffer for tensors that are offloaded + if current_group == (self.num_layers - 1): + for buf in self.reload_double_buffer[0]: + self.reload_double_buffer[1].append( + torch.empty_like(buf) if self.double_buffering else None + ) + self.double_buffer_created = True + + def on_group_commit_forward(self): + """This function will cause host device synchronization""" + # handle synchronization events + self.synchronize_on_group_commit_forward(self.current_group) + + super().on_group_commit_forward() + + def bulk_reload_group(self, group_to_reload): + """Bulk reload group.""" + assert group_to_reload < self.num_offload_group + + buffer_idx = 0 + double_buffer_idx = group_to_reload % 2 + + main_stream = torch.cuda.current_stream() + + with torch.cuda.stream(self.h2d_stream): + # move back tensors + for tensor_label, state in self.tensor_tag_to_state.items(): + group_id, _ = tensor_label + if group_id == group_to_reload: + + if isinstance(state, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state[1], device=torch.cuda.current_device() + ) + + recovered_tensor = SynchronizedGroupOffloadHandler.reload( + state, True, reload_buffer + ) + 
buffer_idx = buffer_idx + 1 + self.tensor_tag_to_state[tensor_label] = recovered_tensor + elif isinstance(state, list): + tensor_list = [] + for state_tuple in state: + + if isinstance(state_tuple, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][ + buffer_idx + ] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state_tuple[1], device=torch.cuda.current_device() + ) + + tensor_list.append( + SynchronizedGroupOffloadHandler.reload( + state_tuple, + True, + reload_buffer, + ) + ) + buffer_idx = buffer_idx + 1 + else: + tensor_list.append(state_tuple) + + # No need to write back the duplicated tensor again + # to the same location, this check ensures that + if tensor_label in self.dereferencing_list: + self.dereferencing_list.remove(tensor_label) + else: + _ = self.fp8_tensor_object_map[tensor_label].restore_from_saved( + tensor_list + ) + + if isinstance(self.fp8_tensor_object_map[tensor_label], Float8Tensor): + self.fp8_tensor_object_map[tensor_label]._transpose_invalid = ( + self.float8_transpose_cache_valid.pop(tensor_label) + ) + + self.tensor_tag_to_state[tensor_label] = self.fp8_tensor_object_map.pop( + tensor_label + ) + + def on_group_commit_backward(self): + # first decrement the current group. + # after last commit in forward, the group will +1; in backward it -1. + # Finally it should be decremented to 0. 
+ self.current_group -= 1 + assert self.current_group >= 0 + + # Layer window data structure helps us to reload at right times + if self.layer_window_map[self.offloaded_group_count - 1] == self.current_group: + + # Stream synchronization both ways + self.h2d_stream.wait_stream(torch.cuda.current_stream()) + torch.cuda.current_stream().wait_stream(self.h2d_stream) + + # Time to reload the next group + self.bulk_reload_group(self.offloaded_group_count - 1) + + # Decrease the offloading group counter + self.offloaded_group_count -= 1 if self.offloaded_group_count > 1 else 0 + + # Last group computation needs to wait till all the reloads complete + if self.current_group == 0: + torch.cuda.current_stream().wait_stream(self.h2d_stream) + self.offloaded_group_count = 0 + + +def get_cpu_offload_context( + enabled: bool = False, + num_layers: int = 1, + model_layers: int = 1, + offload_activations: bool = True, + offload_weights: bool = False, + double_buffering: bool = False, +): + """ + This function returns the CPU Offload context and the synchronizer function that needs to be + used after every transformer layer. Returns `nullcontext()` if offloading is not enabled. + + Usage: + + .. code-block:: python + + cpu_offload_context, cpu_offload_synchronizer = get_cpu_offload_context(enabled=True) + + with cpu_offload_context: + te_layer.forward(inp_tensor) + cpu_offload_synchronizer() + + Parameters + ---------- + enabled: bool, default = `False` + When set to True, CPU Offloading functionality is enabled. + num_layers: int, default = 1 + Determines the number of transformer layers + you want to offload activations/weights for. + model_layers: int, default = 1 + Number of layers in the model that will be used under this context. + offload_activations: bool, default = `True` + When set to `True`, offloads the activations for the TE layer. + offload_weights: bool, default = `False` + When set to `True`, offloads the weights for the TE layer. 
+ double_buffering: bool, default = `False` + When set to `True`, uses double buffering for offloading. + + """ + + if not offload_weights and not offload_activations: + raise ValueError( + "CPU Offloading is enabled while it is not " + "mentioned what to offload (weights/activations)" + ) + + if offload_weights: + import warnings + + warnings.warn( + "Offloading weights is deprecated. Using offload_weights=True does not have any" + " effect.", + DeprecationWarning, + ) + + # Weights offloading is deprecated but we maintain backward compatibility by doing nothing. + if not offload_activations: + return nullcontext(), lambda x: x + + def tensor_need_offloading_checker_activations(tensor): + return hasattr(tensor, "activation_offloading") + + tensor_need_offloading_checker = tensor_need_offloading_checker_activations + + cpu_offload_handler = AsyncDoubleBufferGroupOffloadHandler( + num_offload_group=num_layers, + num_model_group=model_layers, + tensor_need_offloading_checker=tensor_need_offloading_checker, + double_buffering=double_buffering, + ) + + def group_prefetch_offload_commit_async(tensor): + return group_prefetch_offload_commit(tensor, cpu_offload_handler) + + if enabled: + return ( + CpuOffloadHookWithOffloadHandler(offload_handler=cpu_offload_handler), + group_prefetch_offload_commit_async, + ) + return nullcontext(), group_prefetch_offload_commit_async diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index aa6602401..59f57743b 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. @@ -12,12 +12,26 @@ #include "pybind.h" #include "transformer_engine/transformer_engine.h" -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM #include "common/common.h" #endif namespace transformer_engine::pytorch { +/*! convert fp4 data shape back to original shape */ +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose) { + std::vector ret; + size_t start_idx = (transpose) ? 1 : 0; + for (size_t i = start_idx; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() * 2); + if (transpose) { + ret.push_back(shape.front()); + } + return ret; +} + std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { @@ -182,8 +196,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor( const std::vector meta_shape{1}; ret.set_amax(amax_ptr, DType::kFloat32, meta_shape); ret.set_scale(scale_ptr, DType::kFloat32, meta_shape); - auto scale_inv_dtype = - (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32; + auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 + : (scaling_mode == NVTE_NVFP4_1D_SCALING) ? 
DType::kFloat8E4M3 + : DType::kFloat32; ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape); ret.set_columnwise_scale_inv(columnwise_scale_inv_ptr, scale_inv_dtype, columnwise_scale_inv_shape); @@ -297,7 +312,7 @@ size_t roundup(const size_t value, const size_t multiple) { return ((value + multiple - 1) / multiple) * multiple; } -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM inline bool nvte_use_atomic_amax() { const char *env_p = std::getenv("NVTE_USE_ATOMIC_AMAX"); @@ -320,4 +335,20 @@ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor) { #endif +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr) { + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, + arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, + at::cuda::getCurrentCUDAStream()); + }); +} + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread) { + at::PhiloxCudaState philox_args; + std::lock_guard lock(gen->mutex_); + philox_args = gen->philox_cuda_state(elts_per_thread); + return philox_args; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 07384413d..fd83f20d4 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -37,6 +37,7 @@ #include #include #include +#include #include #include #include @@ -204,20 +205,25 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; - /*! @brief Construct a high precision tensor giving it this quantizer's amax - - Note: this member function also zeros out the amax, as it is meant to be used in conjunction with - a kernel computing the amax, which might expect the amax to be initialized to zero + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. */ - std::pair create_hp_tensor_with_amax(const std::vector& shape, - DType dtype); + std::pair create_unquantized_tensor_with_amax( + const std::vector& shape, DType dtype, std::optional data = std::nullopt); std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; - /*! @brief Convert to a quantized data format avoiding amax computation */ + /*! @brief Quantize to FP8, skipping local amax computation + * + * The quantizer's amax pointer is assumed to already hold the local + * amax. The amax may still be reduced across the amax reduction + * group. 
+ */ void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt); @@ -287,6 +293,62 @@ class MXFP8Quantizer : public Quantizer { std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; +#ifndef __HIP_PLATFORM_AMD__ +class NVFP4Quantizer : public Quantizer { + public: + // fp4 dtype + DType dtype; + // amax reduction for low precision FP4 AG + bool with_amax_reduction; + c10::intrusive_ptr amax_reduction_group; + // random hadamard transform + bool with_rht; + bool with_post_rht_amax; + // 2D block scaling + bool with_2d_quantization; + bool stochastic_rounding; + + int rht_matrix_random_sign_mask_t; + at::Tensor rht_matrix; + + explicit NVFP4Quantizer(const py::handle& quantizer); + + NVTEScalingMode get_scaling_mode() const override { return NVTE_NVFP4_1D_SCALING; } + + void set_quantization_params(TensorWrapper* tensor) const override; + + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer + * + * The amax is zeroed out. Most TE kernels that output amax expect + * amax to be initialized to zero. + */ + std::pair create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype); + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Quantize to NVFP4, skipping local amax computation + * + * The input tensor's amax pointer is assumed to already hold the + * local amax. The amax may still be reduced across the amax + * reduction group. 
+ */ + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out); + + std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; + + private: + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, bool compute_amax); +}; +#endif // #ifndef __HIP_PLATFORM_AMD__ + std::unique_ptr convert_quantizer(py::handle quantizer); std::vector getTensorShape(const at::Tensor& t); @@ -448,6 +510,15 @@ NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); #ifdef __HIP_PLATFORM_AMD__ at::Tensor allocate_amax_workspace(const TensorWrapper& input_tensor); #endif + +std::vector convert_shape_back_from_fp4(const std::vector& shape, bool transpose); + +// unpack the PhiloxCudaState into CUDA tensor +void philox_unpack(at::PhiloxCudaState arg, int64_t* rng_state_ptr); + +// extract PhiloxCudaState from CUDA random number generator +at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl* gen, size_t elts_per_thread); + } // namespace transformer_engine::pytorch namespace std { diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9b527b161..73f273ef8 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -83,32 +83,40 @@ std::tuple moe_unpermute_bwd(at::Tensor input_bwd, at::T NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right); + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, 
size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph); + +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data); std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread); + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph); std::vector fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor 
cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer); + py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph); at::Tensor fa_prepare_fwd(at::Tensor qkvi); at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v); @@ -207,6 +215,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer); py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha); + +py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer, + float limit, float alpha); /*************************************************************************************************** * LayerNorm **************************************************************************************************/ @@ -344,6 +356,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 8b0607c9e..205605312 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -1,188 +1,347 @@ /************************************************************************* * This file was modified 
for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ - #include "../extensions.h" #include "common.h" #include "pybind.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { + +namespace { +using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t); +using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t); -template -py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { +template +py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1, + Args&&... args) { init_extension(); // Input tensor auto input_tensor = input.contiguous(); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct output tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape = input_cpp.shape(); + const auto input_shape = input_nvte.shape(); std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); output_shape.back() /= shape_divisor; auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + auto [out_nvte, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); - // Compute activation + // Choose implementation + enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || 
detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation directly - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation in high-precision fused together with amax, then quantize. - - auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); - } else { - // Compute activation in high-precision, then quantize - - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE( - { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); - quantizer_cpp->quantize(temp_cpp, out_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifdef USE_ROCM + } +#else + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } + } +#endif + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), 
temp_nvte.data(), stream); + } + }); + quantizer_cpp->quantize(temp_nvte, out_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation directly + { + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), out_nvte.data(), stream); + } + }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; +#ifndef USE_ROCM + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (act_func == nullptr) { + act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward(args)..., + stream); + } else { + act_func(input_nvte.data(), temp_nvte.data(), stream); + } + }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte); + } + break; +#endif + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return out_py; } -template +template py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, - py::handle 
quantizer) { + py::handle quantizer, Args&&... args) { init_extension(); // Grad output and input tensors auto grad_output_tensor = grad_output.contiguous(); auto input_tensor = input.contiguous(); - const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); - const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + const TensorWrapper& grad_output_nvte = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& input_nvte = makeTransformerEngineTensor(input_tensor); // Construct grad input tensor auto quantizer_cpp = convert_quantizer(quantizer); - const auto input_shape_te = input_cpp.shape(); + const auto input_shape_te = input_nvte.shape(); const std::vector input_shape(input_shape_te.data, input_shape_te.data + input_shape_te.ndim); auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); - auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); - // Compute activation backward + // Choose implementation + enum class Impl { UNFUSED, FULLY_FUSED, FUSED_ACTIVATION_AMAX_FP8, FUSED_ACTIVATION_AMAX_NVFP4 }; + Impl impl = Impl::UNFUSED; if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr())) { - // Compute activation backward directly - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); + impl = Impl::FULLY_FUSED; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // Compute activation backward in high-precision fused together with amax, then quantize. 
- auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); - } else { - // Compute activation backward in high-precision, then quantize - auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); - NVTE_SCOPED_GIL_RELEASE({ - dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), - at::cuda::getCurrentCUDAStream()); - }); - quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + impl = Impl::FUSED_ACTIVATION_AMAX_FP8; +#ifdef USE_ROCM + } +#else + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; + } + } +#endif + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Compute activation backward in high precision, then quantize + { + auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); + quantizer_cpp->quantize(temp_nvte, grad_input_nvte); + } + break; + case Impl::FULLY_FUSED: + // Compute activation backward directly + { + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) 
{ + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream); + } + }); + } + break; + case Impl::FUSED_ACTIVATION_AMAX_FP8: + // Compute activation and amax in high precision, then quantize to FP8 + { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + auto [temp_nvte, _] = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; +#ifndef USE_ROCM + case Impl::FUSED_ACTIVATION_AMAX_NVFP4: + // Compute activation and amax in high precision, then quantize to NVFP4 + { + auto nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + auto [temp_nvte, _] = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + if constexpr (dact_func == nullptr) { + dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), + std::forward(args)..., stream); + } else { + dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); + } + }); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + } + break; +#endif + default: + NVTE_ERROR("Invalid activation implementation (", static_cast(impl), ")"); } return grad_input_py; } +} // namespace -/* GELU and variants*/ +/* GELU and variants */ py::object gelu(const at::Tensor& input, py::handle quantizer) { - return 
activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } -/* ReLU and variants*/ +/* ReLU and variants */ py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object reglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dreglu(const at::Tensor& grad, const 
at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object sreglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } - -/* Silu and variants*/ +/* Silu and variants */ py::object silu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); + return activation_helper(input, quantizer); } py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); } py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); + return activation_helper(input, quantizer, 2); } py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); + return dactivation_helper(grad, input, quantizer); +} + +/* clamped functions */ +py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) { + return activation_helper(input, quantizer, 2, limit, alpha); } -} // namespace transformer_engine::pytorch + +py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer, + float 
limit, float alpha) { + return dactivation_helper(grad, input, quantizer, limit, alpha); +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 064da8a67..d1dcf68c3 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -163,6 +163,7 @@ std::tuple fused_qkv_rope_forward( } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank) { @@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); + } + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); @@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, - max_s, b, h, d, d2, stride_t, + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; @@ -246,9 +253,9 
@@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, - h, d, d2, stride_s, stride_b, stride_h, stride_d, - at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; } diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 6d835a5c9..d51aef406 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -35,22 +35,6 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s { nvte_memset(base_ptr, 0, total_bytes, at::cuda::getCurrentCUDAStream()); }); } -void unpack(at::PhiloxCudaState arg, int64_t *rng_state_ptr) { - NVTE_SCOPED_GIL_RELEASE({ - nvte_extract_seed_and_offset(rng_state_ptr, arg.captured_, arg.seed_.ptr, arg.seed_.val, - arg.offset_.ptr, arg.offset_.val, arg.offset_intragraph_, - at::cuda::getCurrentCUDAStream()); - }); -} - -// extract PhiloxCudaState from CUDA random number generator -at::PhiloxCudaState init_philox_state(at::CUDAGeneratorImpl *gen, size_t elts_per_thread) { - at::PhiloxCudaState philox_args; - std::lock_guard lock(gen->mutex_); - philox_args = gen->philox_cuda_state(elts_per_thread); - return philox_args; -} - } // namespace namespace transformer_engine::pytorch { @@ -58,66 +42,96 @@ namespace transformer_engine::pytorch { // get the fused attention backend NVTE_Fused_Attn_Backend get_fused_attn_backend( bool is_training, const DType q_dtype, const DType kv_dtype, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, 
NVTE_Mask_Type attn_mask_type, float p_dropout, size_t num_attn_heads, - size_t num_gqa_groups, size_t max_seqlen_q, size_t max_seqlen_kv, size_t head_dim_qk, - size_t head_dim_v, int64_t window_size_left, int64_t window_size_right) { + NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q, + size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left, + int64_t window_size_right, bool return_max_logit, bool cuda_graph) { NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( is_training, static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, - bias_type, attn_mask_type, p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q, - max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right); + bias_type, attn_mask_type, softmax_type, p_dropout, num_attn_heads, num_gqa_groups, + max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, window_size_left, window_size_right, + return_max_logit, cuda_graph); return fused_attention_backend; } +// helper function for S and dP quantizers +std::pair quantizer_helper(py::handle quantizer, + const std::vector &shape, DType dtype, + bool create_hp_tensor_for_cs, + std::optional data) { + std::unique_ptr T_quantizer = convert_quantizer(quantizer); + TensorWrapper te_T; + py::object py_T; + if (quantizer.is_none()) { + // high precision + auto *none_quantizer = dynamic_cast(T_quantizer.get()); + if (data.has_value()) { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = none_quantizer->create_tensor(shape, dtype); + } + } else if (detail::IsFloat8Quantizers(quantizer.ptr())) { + // delayed scaling; this helps initialize scale_inv + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + std::tie(te_T, py_T) = + T_quantizer_fp8->create_tensor(shape, dtype, data, std::nullopt, 
std::nullopt); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // current scaling + auto *T_quantizer_fp8 = dynamic_cast(T_quantizer.get()); + if (create_hp_tensor_for_cs) { + if (data.has_value()) { + std::tie(te_T, py_T) = + T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype, data.value()); + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_unquantized_tensor_with_amax(shape, dtype); + } + } else { + std::tie(te_T, py_T) = T_quantizer_fp8->create_tensor(shape, dtype); + NVTE_CHECK( + !data.has_value(), + "Float8CurrentScalingQuantizer::create_tensor() does not take data tensor as input!"); + } + } + return {std::move(te_T), std::move(py_T)}; +} + // fused attention FWD with separate Q, K and V tensors std::vector fused_attn_fwd( size_t max_seqlen_q, size_t max_seqlen_kv, bool is_training, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type attn_mask_type, const std::vector window_size, - const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, - const py::handle K, const py::handle V, const at::ScalarType fake_dtype, - const std::optional cu_seqlens_q_padded, + NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, + const std::vector window_size, const at::Tensor cu_seqlens_q, + const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, + const at::ScalarType fake_dtype, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, const std::optional page_table_k, const std::optional page_table_v, py::handle s_quantizer, py::handle o_quantizer, const std::optional Bias, - const std::optional rng_gen, size_t rng_elts_per_thread) { - TensorWrapper te_Q, te_K, te_V, te_O, te_S; - + const std::optional SoftmaxOffset, const std::optional rng_gen, + size_t rng_elts_per_thread, bool return_max_logit, bool cuda_graph) { auto none = py::none(); - std::unique_ptr 
S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); + // create QKV tensor wrappers + TensorWrapper te_Q, te_K, te_V; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); - - // If qkv has FP8 dtype, fake_dtype_te is equal to the fake dtype of q, k, v - needed since torch do not have fp8 types. const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + // create S tensor + TensorWrapper te_S; + py::object py_S; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + + // create O tensor + TensorWrapper te_O; + py::object py_O; + std::unique_ptr O_quantizer = convert_quantizer(o_quantizer); std::vector q_shape = convertShape(te_Q.shape()); - std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); - auto options = torch::TensorOptions().dtype(GetATenDType(qkv_type)).device(torch::kCUDA); - // create output tensor O - auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; - py::object o_python, s_python; - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - // Initialize FP8 tensor with scale-inverse - auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_O, o_python) = 
O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - } - auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); + std::tie(te_O, py_O) = quantizer_helper(o_quantizer, o_shape, fake_dtype_te, true, std::nullopt); // construct NVTE tensors TensorWrapper te_Bias; @@ -128,11 +142,12 @@ std::vector fused_attn_fwd( // FP8 auto h = q_shape[q_shape.size() - 2]; auto d = q_shape[q_shape.size() - 1]; - if (set_zero && ((h * d) % block_size == 0) && - (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - te_O.zero_(at::cuda::getCurrentCUDAStream()); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if ((h * d) % block_size == 0) { + mha_fill(te_O, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + te_O.zero_(at::cuda::getCurrentCUDAStream()); + } } } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { @@ -181,12 +196,23 @@ std::vector fused_attn_fwd( DType::kInt32, nullptr, nullptr, nullptr); } + // softmax offset + TensorWrapper te_SoftmaxOffset; + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + auto SoftmaxOffset_sizes = SoftmaxOffset.value().sizes().vec(); + std::vector SoftmaxOffset_shape{SoftmaxOffset_sizes.begin(), SoftmaxOffset_sizes.end()}; + te_SoftmaxOffset = + makeTransformerEngineTensor(SoftmaxOffset.value().data_ptr(), SoftmaxOffset_shape, + DType::kFloat32, nullptr, nullptr, nullptr); + } + // extract rng seed and offset auto gen = at::get_generator_or_default( rng_gen, at::cuda::detail::getDefaultCUDAGenerator()); at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); - auto 
rng_state = torch::empty({2}, options.dtype(torch::kInt64)); - unpack(philox_args, static_cast(rng_state.data_ptr())); + auto options = torch::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, options); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); auto te_rng_state = makeTransformerEngineTensor(rng_state); // create auxiliary output tensors @@ -199,12 +225,13 @@ std::vector fused_attn_fwd( // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace and auxiliary output tensors @@ -214,53 +241,57 @@ std::vector fused_attn_fwd( // output_tensors = [O, nvte_aux_tensor_pack.tensors] std::vector output_tensors; - output_tensors.push_back(o_python); - for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) { - // allocate memory for nvte_aux_tensor_pack.tensors - at::Tensor output_tensor; - if (nvte_aux_tensor_pack.size >= 2) { - if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { - if (i < nvte_aux_tensor_pack.size - 2) { - NVTEShape temp_shape = 
nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } else if (i == nvte_aux_tensor_pack.size - 2) { - output_tensor = rng_state; - } else if (i == nvte_aux_tensor_pack.size - 1) { - output_tensor = Bias.value(); - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = - (i < nvte_aux_tensor_pack.size - 1) - ? allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false) - : rng_state; - } - } else { - NVTEShape temp_shape = nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i]); - output_tensor = allocateSpace( - nvte_shape_to_vector(temp_shape), - static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); - } + output_tensors.push_back(py_O); + auto set_tensor_param = [&](size_t i, const at::Tensor &output_tensor) { output_tensors.push_back(py::cast(output_tensor)); NVTEBasicTensor temp_data = {output_tensor.data_ptr(), nvte_tensor_type(nvte_aux_tensor_pack.tensors[i]), nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])}; nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); + }; + // allocate memory for nvte_aux_tensor_pack.tensors + // f16_max512 : S [b, h, sq, skv] + // f16_arbitrary: + // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2] + size_t i = 0; + at::Tensor output_tensor; + // intermediate softmax tensor, S or M + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + 
static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor + if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + output_tensor = + allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), + static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); + set_tensor_param(i++, output_tensor); + } + // rng_state + if (i < nvte_aux_tensor_pack.size) { + set_tensor_param(i++, rng_state); + } + // bias (optional) + if ((bias_type != NVTE_NO_BIAS) && (bias_type != NVTE_ALIBI) && (Bias.has_value())) { + set_tensor_param(i++, Bias.value()); + } + // softmax_offset (optional) + if ((softmax_type != NVTE_VANILLA_SOFTMAX) && (SoftmaxOffset.has_value())) { + set_tensor_param(i++, SoftmaxOffset.value()); } // execute the kernel NVTE_SCOPED_GIL_RELEASE({ nvte_fused_attn_fwd( - te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_S.data(), te_O.data(), - &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_Q.data(), te_K.data(), te_V.data(), te_Bias.data(), te_SoftmaxOffset.data(), te_S.data(), + te_O.data(), &nvte_aux_tensor_pack, te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), te_page_table_k.data(), te_page_table_v.data(), te_rng_state.data(), max_seqlen_q, max_seqlen_kv, is_training, - attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, window_size[0], - window_size[1], workspace.data(), at::cuda::getCurrentCUDAStream()); + return_max_logit, cuda_graph, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], workspace.data(), + at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers, but not allocated memory @@ -274,58 +305,53 @@ std::vector fused_attn_fwd( std::vector 
fused_attn_bwd( size_t max_seqlen_q, size_t max_seqlen_kv, float attn_scale, float p_dropout, bool set_zero, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, - const std::vector window_size, bool deterministic, const at::Tensor cu_seqlens_q, - const at::Tensor cu_seqlens_kv, const py::handle Q, const py::handle K, const py::handle V, - const py::handle O, const py::handle dO, const at::ScalarType fake_dtype, const DType dqkv_type, + NVTE_Softmax_Type softmax_type, const std::vector window_size, bool deterministic, + const at::Tensor cu_seqlens_q, const at::Tensor cu_seqlens_kv, const py::handle Q, + const py::handle K, const py::handle V, const py::handle O, const py::handle dO, + const at::ScalarType fake_dtype, const DType dqkv_type, const std::vector Aux_CTX_Tensors, const std::optional cu_seqlens_q_padded, const std::optional cu_seqlens_kv_padded, py::handle s_quantizer, - py::handle dp_quantizer, py::handle dqkv_quantizer) { + py::handle dp_quantizer, py::handle dqkv_quantizer, bool cuda_graph) { auto none = py::none(); - TensorWrapper te_Q, te_K, te_V, te_O, te_dO, te_S, te_dP, te_dQ, te_dK, te_dV; + + // create QKV, O, dO tensor wrappers + TensorWrapper te_Q, te_K, te_V, te_O, te_dO; te_Q = makeTransformerEngineTensor(Q, none); te_K = makeTransformerEngineTensor(K, none); te_V = makeTransformerEngineTensor(V, none); te_O = makeTransformerEngineTensor(O, none); te_dO = makeTransformerEngineTensor(dO, none); - // qkv type from the te_Q - std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); - const DType qkv_type = te_Q.dtype(); - const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); - - py::object s_python, dp_python; - std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); - std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); - auto 
*dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); - NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, - std::nullopt, std::nullopt); - } else { - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); - } + // create S and dP tensors + TensorWrapper te_S, te_dP; + py::object py_S, py_dP; + std::tie(te_S, py_S) = quantizer_helper(s_quantizer, {0}, DType::kFloat32, false, std::nullopt); + std::tie(te_dP, py_dP) = + quantizer_helper(dp_quantizer, {0}, DType::kFloat32, false, std::nullopt); + // create dQ, dK, dV tensors + TensorWrapper te_dQ, te_dK, te_dV; + py::object py_dQ, py_dK, py_dV; + std::unique_ptr dQKV_quantizer = convert_quantizer(dqkv_quantizer); std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); std::vector v_shape = convertShape(te_V.shape()); auto h_q = q_shape[q_shape.size() - 2]; auto h_kv = k_shape[k_shape.size() - 2]; auto d_qk = q_shape[q_shape.size() - 1]; - auto d_v = v_shape[v_shape.size() - 1]; - auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); - std::vector o_shape{q_shape.begin(), q_shape.end()}; - o_shape[o_shape.size() - 1] = d_v; + const DType fake_dtype_te = GetTransformerEngineDType(fake_dtype); at::Tensor dQ, dK, dV, dQKV, dKV; - py::object py_dQ, py_dK, py_dV; NVTE_QKV_Layout_Group layout_group = nvte_get_qkv_layout_group(qkv_layout); std::vector tmp_shape; + auto options = torch::TensorOptions().dtype(GetATenDType(dqkv_type)).device(torch::kCUDA); + if (dqkv_type == DType::kFloat8E4M3 || 
dqkv_type == DType::kFloat8E5M2) { + options = options.dtype(torch::kUInt8); + } + if (detail::IsFloat8CurrentScalingQuantizers(dqkv_quantizer.ptr())) { + options = options.dtype(fake_dtype); + } switch (layout_group) { case NVTE_QKV_Layout_Group::NVTE_3HD: @@ -398,39 +424,27 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { - auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); - std::tie(te_dQ, py_dQ) = - fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); - std::tie(te_dK, py_dK) = - fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); - std::tie(te_dV, py_dV) = - fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); - } else { - auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); - NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); - std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); - } + + std::tie(te_dQ, py_dQ) = quantizer_helper(dqkv_quantizer, q_shape, fake_dtype_te, true, dQ); + std::tie(te_dK, py_dK) = quantizer_helper(dqkv_quantizer, k_shape, fake_dtype_te, true, dK); + std::tie(te_dV, py_dV) = quantizer_helper(dqkv_quantizer, v_shape, fake_dtype_te, true, dV); // construct NVTE tensors - if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + if (dqkv_type == DType::kFloat8E4M3 || dqkv_type == DType::kFloat8E5M2) { // FP8 - if (set_zero && ((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && - dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous() && - 
(nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { - mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); - } else { - dQ.fill_(0); - dK.fill_(0); - dV.fill_(0); + if (set_zero && (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD)) { + if (((h_q * d_qk) % block_size == 0) && ((h_kv * d_qk) % block_size == 0) && + dQ.is_contiguous() && dK.is_contiguous() && dV.is_contiguous()) { + mha_fill(te_dQ, cu_seqlens_q.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dK, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + mha_fill(te_dV, cu_seqlens_kv.index({torch::indexing::Slice(-1, torch::indexing::None)})); + } else { + dQ.fill_(0); + dK.fill_(0); + dV.fill_(0); + } } - - } else if (qkv_type == DType::kBFloat16 || qkv_type == DType::kFloat16) { + } else if (dqkv_type == DType::kBFloat16 || dqkv_type == DType::kFloat16) { if (nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format::NVTE_THD) { dQ.fill_(0); dK.fill_(0); @@ -499,18 +513,28 @@ std::vector fused_attn_bwd( } } + // create dSoftmaxOffset in the same shape as SoftmaxOffset + at::Tensor dSoftmaxOffset; + TensorWrapper te_dSoftmaxOffset; + if (softmax_type != NVTE_VANILLA_SOFTMAX) { + options = torch::TensorOptions().dtype(at::kFloat).device(torch::kCUDA); + dSoftmaxOffset = torch::empty({1, static_cast(h_q), 1, 1}, options); + te_dSoftmaxOffset = makeTransformerEngineTensor(dSoftmaxOffset); + } + // create workspace TensorWrapper workspace; // populate tensors with appropriate shapes and dtypes NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - 
te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], window_size[1], deterministic, cuda_graph, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // allocate memory for workspace @@ -520,19 +544,20 @@ std::vector fused_attn_bwd( // execute kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_fused_attn_bwd( - te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), te_S.data(), te_dP.data(), - &nvte_aux_tensor_pack, te_dQ.data(), te_dK.data(), te_dV.data(), te_dBias.data(), - te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), te_cu_seqlens_q_padded.data(), - te_cu_seqlens_kv_padded.data(), max_seqlen_q, max_seqlen_kv, attn_scale, p_dropout, - qkv_layout, bias_type, attn_mask_type, window_size[0], window_size[1], deterministic, - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_fused_attn_bwd(te_Q.data(), te_K.data(), te_V.data(), te_O.data(), te_dO.data(), + te_S.data(), te_dP.data(), &nvte_aux_tensor_pack, te_dQ.data(), + te_dK.data(), te_dV.data(), te_dBias.data(), te_dSoftmaxOffset.data(), + te_cu_seqlens_q.data(), te_cu_seqlens_kv.data(), + te_cu_seqlens_q_padded.data(), te_cu_seqlens_kv_padded.data(), max_seqlen_q, + max_seqlen_kv, attn_scale, p_dropout, qkv_layout, bias_type, attn_mask_type, + softmax_type, window_size[0], 
window_size[1], deterministic, cuda_graph, + workspace.data(), at::cuda::getCurrentCUDAStream()); }); // destroy tensor wrappers nvte_tensor_pack_destroy(&nvte_aux_tensor_pack); - return {py_dQ, py_dK, py_dV, py::cast(dBias)}; + return {py_dQ, py_dK, py_dV, py::cast(dBias), py::cast(dSoftmaxOffset)}; } at::Tensor fa_prepare_fwd(at::Tensor qkvi) { @@ -598,7 +623,6 @@ at::Tensor thd_read_half_tensor(const at::Tensor &tensor, const at::Tensor &cu_s // Shapes of kv and dkv are [2, t, h, d], so the dimension of "t" is 1 int seq_dim = tensor.dim() == 3 ? 0 : 1; - int batch = cu_seqlens.size(0) - 1; int num_heads = tensor.size(seq_dim + 1); int dim_per_head = tensor.size(seq_dim + 2); int hidden_size_in_bytes = num_heads * dim_per_head * c10::elementSize(tensor.scalar_type()); @@ -762,8 +786,6 @@ at::Tensor thd_get_partitioned_indices(const at::Tensor &cu_seqlens, int total_t NVTE_CHECK(world_size > 0); NVTE_CHECK(total_tokens > 0 && total_tokens % (world_size * 2) == 0); - int batch = cu_seqlens.size(0) - 1; - std::vector shape = {total_tokens / world_size}; at::Tensor output = at::empty(shape, at::CUDA(at::ScalarType::Int)); @@ -801,7 +823,6 @@ at::Tensor convert_thd_to_bshd(at::Tensor tensor, at::Tensor cu_seqlens, int b, **************************************************************************************************/ at::Tensor convert_bshd_to_thd(at::Tensor tensor, at::Tensor cu_seqlens, int t) { - int max_seq_len = tensor.size(1); int h = tensor.size(2); int d = tensor.size(3); std::vector shape = {t, h, d}; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index f65614d07..e8a735966 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -56,10 +56,25 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; } - // Unfused impl if quantizer 
is not supported - const bool with_fused_dbias_quantize_kernel = - detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); - if (!with_fused_dbias_quantize_kernel) { + // Check if fused kernel is supported + bool with_fused_kernel = false; + if (detail::IsFloat8Quantizers(quantizer.ptr())) { + auto prop = at::cuda::getCurrentDeviceProperties(); + const size_t sm_arch = 10 * prop->major + prop->minor; + if (sm_arch >= 100) { + // Fused kernel for dbias + FP8 cast on SM arch 10.0+ + with_fused_kernel = true; + } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) { + // Fused kernel for dbias + FP8 cast + FP8 transpose + with_fused_kernel = true; + } + } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Fused kernel for dbias + MXFP8 quantize + with_fused_kernel = true; + } + + // Apply unfused impl if fused kernel is not supported + if (!with_fused_kernel) { at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; @@ -124,14 +139,32 @@ std::vector dact_dbias( } // Choose implementation - enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; + enum class Impl { + UNFUSED, + FUSED_DACT_DBIAS_QUANTIZE, + FUSED_DACT_AMAX_FP8, + FUSED_DACT_AMAX_NVFP4 + }; Impl impl = Impl::UNFUSED; if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || detail::IsMXFP8Quantizers(quantizer_py.ptr())) { impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - impl = Impl::FUSED_DACT_AMAX; + impl = Impl::FUSED_DACT_AMAX_FP8; +#ifdef USE_ROCM + } +#else + } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && 
nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else { + impl = Impl::FUSED_DACT_AMAX_NVFP4; + } } +#endif // Perform compute auto stream = at::cuda::getCurrentCUDAStream(); @@ -174,22 +207,42 @@ std::vector dact_dbias( }); break; } - case Impl::FUSED_DACT_AMAX: - // Fused dact-amax kernel, unfused dbias and quantize + case Impl::FUSED_DACT_AMAX_FP8: + // Fused dact-amax kernel, unfused dbias and FP8 quantize { - auto *quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); - NVTE_CHECK(quantizer_cpp_cs != nullptr, + auto *fp8_quantizer_cpp = + dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Invalid quantizer for fused dact-amax kernel impl"); auto [temp_nvte, temp_py] = - quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); + break; + } +#ifndef USE_ROCM + case Impl::FUSED_DACT_AMAX_NVFP4: + // Fused dact-amax kernel, unfused dbias and NVFP4 quantize + { + auto *nvfp4_quantizer_cpp = + static_cast(quantizer_cpp.get()); // Already checked cast is valid + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, + "Invalid quantizer for fused dact-amax kernel impl"); + auto [temp_nvte, temp_py] = nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax( + grad_input_nvte, grad_output_dtype); NVTE_SCOPED_GIL_RELEASE({ dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); }); const auto temp_torch = temp_py.cast(); at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); - quantizer_cpp_cs->quantize_with_amax(temp_nvte, 
grad_input_nvte); + nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte); break; } +#endif default: NVTE_ERROR("Invalid implementation"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index c940181b0..8fc4e1e97 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -39,7 +39,18 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob // Convert input tensor to C++ object auto input_contiguous = tensor.contiguous(); - const auto input_cpp = makeTransformerEngineTensor(input_contiguous); + auto input_cpp = makeTransformerEngineTensor(input_contiguous); + + // Set amax if use_existing_amax = true (only valid for CS) + bool use_existing_amax = false; + if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + use_existing_amax = quantizer.attr("use_existing_amax").cast(); + if (use_existing_amax) { + const at::Tensor &amax = quantizer.attr("amax").cast(); + input_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + } // Initialize output tensor TensorWrapper output_cpp; @@ -59,7 +70,12 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob } // Perform quantization - quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + if (use_existing_amax) { + auto *quantizer_cs = dynamic_cast(quantizer_cpp.get()); + quantizer_cs->quantize_with_amax(input_cpp, output_cpp, noop_flag_cpp); 
+ } else { + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); + } return output_py; } @@ -300,7 +316,7 @@ std::tuple, std::vector> bulk_allocate_fp // Construct FP8 block-wise tensors py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); for (size_t i = 0; i < num_tensors; ++i) { // Create tensor objects with proper reference counting py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); @@ -447,7 +463,7 @@ std::tuple, std::vector> bulk_allocate_mx } // Construct mxfp8 tensors - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); for (size_t i = 0; i < num_tensors; ++i) { // Create tensor objects with proper reference counting py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); @@ -477,6 +493,209 @@ std::tuple, std::vector> bulk_allocate_mx return retval; } +#ifndef USE_ROCM +// allocate fp4 data, fp8 scalings, and amax values +// layout: [fp4_data0, ..., fp4_dataN, fp8_scaling0, ..., fp8_scalingN, amax0, ..., amaxN] +// amax buffer will be zeroed out by later amax kernels, so we can use empty to allocate +std::tuple, std::vector> bulk_allocate_nvfp4_tensors( + std::vector> &shape_list, std::vector &quantizer_py_list, + std::vector &quantizer_cpp_list) { + init_extension(); + std::tuple, std::vector> retval; + auto &tensor_py_list = std::get<0>(retval); + auto &tensor_cpp_list = std::get<1>(retval); + + // Number of tensors + const size_t num_tensors = shape_list.size(); + if (num_tensors == 0) { + return retval; + } + + // Quantization parameters + const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); + const auto fp4_dtype 
= quantizer_cpp_list[0]->dtype; + constexpr size_t scale_elem_size = 1; + + // Helper function to construct tensor view + // Note: Deleter holds a shared_ptr for the buffer, so the buffer + // will survive until all views are deleted. + auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, + size_t offset, at::ScalarType dtype) -> at::Tensor { + std::vector shape_int64(shape.begin(), shape.end()); + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { + return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); + } + return at::from_blob( + buffer->data_ptr() + offset, shape_int64, + [buffer](void *) {}, // deleter holds shared_ptr + at::device(at::kCUDA).dtype(dtype)); + }; + + // Lambda function for converting std::vector shape to NVFP4 shape (last dim divided by 2) + auto to_fp4_shape = [](const std::vector &shape) { + std::vector fp4_shape(shape.begin(), shape.end()); + if (!fp4_shape.empty()) { + fp4_shape.back() /= 2; + } + return fp4_shape; + }; + + // Allocate row-wise data + std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; + std::vector> rowwise_data_shapes, rowwise_scale_shapes; + if (rowwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_shapes.emplace_back(shape_list[i]); + rowwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], false)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets, amax_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + data_offsets.push_back(buffer_size); + // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes). + // Integer arithmetic: ceil(product / 2) == (product + 1) / 2. 
+ buffer_size += (product(rowwise_data_shapes[i]) + 1) / 2; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(rowwise_scale_shapes[i]) * scale_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + amax_offsets.push_back(buffer_size); + // amax is scalar in fp32, 4 bytes each + buffer_size += 4; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + rowwise_data_list.emplace_back(make_torch_view(buffer, to_fp4_shape(rowwise_data_shapes[i]), + data_offsets[i], torch::kUInt8)); + rowwise_scale_list.emplace_back( + make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + amax_rowwise_list.emplace_back( + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + } + } + + // Allocate column-wise data + std::vector columnwise_data_list, columnwise_scale_list, amax_columnwise_list; + std::vector> columnwise_data_shapes, columnwise_scale_shapes; + if (columnwise_usage) { + // Tensor sizes + for (size_t i = 0; i < num_tensors; ++i) { + // push the transposed shape into NVFP4 columnwise shape + // NVFP4 on SM100 is TN only + columnwise_data_shapes.emplace_back(); + auto &shape = columnwise_data_shapes.back(); + shape.push_back(shape_list[i].back()); + for (size_t j = 0; j < shape_list[i].size() - 1; ++j) { + shape.push_back(shape_list[i][j]); + } + columnwise_scale_shapes.emplace_back( + quantizer_cpp_list[i]->get_scale_shape(shape_list[i], true)); + } + + // Offsets in full buffer + size_t buffer_size = 0; + std::vector data_offsets, scale_offsets, amax_offsets; + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 256); // align to 256B + 
data_offsets.push_back(buffer_size); + // Store ceil(product/2) bytes for fp4 (since each element is 4 bits = 0.5 bytes). + // Integer arithmetic: ceil(product / 2) == (product + 1) / 2. + buffer_size += (product(columnwise_data_shapes[i]) + 1) / 2; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_offsets.push_back(buffer_size); + buffer_size += product(columnwise_scale_shapes[i]) * scale_elem_size; + } + for (size_t i = 0; i < num_tensors; ++i) { + buffer_size = roundup(buffer_size, 16); // align to 16B + amax_offsets.push_back(buffer_size); + // amax is scalar in fp32, 4 bytes each + buffer_size += 4; + } + + // Allocate full buffer + auto buffer = std::make_shared( + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + + // Construct tensor views + for (size_t i = 0; i < num_tensors; ++i) { + columnwise_data_list.emplace_back(make_torch_view( + buffer, to_fp4_shape(columnwise_data_shapes[i]), data_offsets[i], torch::kUInt8)); + columnwise_scale_list.emplace_back( + make_torch_view(buffer, columnwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + amax_columnwise_list.emplace_back( + make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kUInt8)); + } + } + + // Construct nvfp4 tensors + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); + for (size_t i = 0; i < num_tensors; ++i) { + // Create tensor objects with proper reference counting + py::object rowwise_data = rowwise_usage ? py::cast(rowwise_data_list[i]) : py::none(); + py::object rowwise_scale = rowwise_usage ? py::cast(rowwise_scale_list[i]) : py::none(); + py::object columnwise_data = + (columnwise_usage ? py::cast(columnwise_data_list[i]) : py::none()); + py::object columnwise_scale = + (columnwise_usage ? py::cast(columnwise_scale_list[i]) : py::none()); + py::object amax_rowwise = rowwise_usage ? 
py::cast(amax_rowwise_list[i]) : py::none(); + py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); + + // Construct Python tensor + tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, amax_rowwise, amax_columnwise, + fp4_dtype, quantizer_py_list[i])); + + // Construct C++ tensor + // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, + // then set the amax and amax_columnwise values. + { + auto tensor_wrapper = makeTransformerEngineTensor( + rowwise_usage ? rowwise_data_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_data_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_data_shapes[i] : std::vector{}, + columnwise_usage ? columnwise_data_shapes[i] : std::vector{}, fp4_dtype, + /*amax_ptr=*/nullptr, + /*scale_ptr=*/nullptr, rowwise_usage ? rowwise_scale_list[i].data_ptr() : nullptr, + columnwise_usage ? columnwise_scale_list[i].data_ptr() : nullptr, + rowwise_usage ? rowwise_scale_shapes[i] : std::vector{}, + columnwise_usage ? 
columnwise_scale_shapes[i] : std::vector{}, scaling_mode); + + // Set the amax rowwise and amax columnwise if available + if (rowwise_usage) { + tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (columnwise_usage) { + tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, + std::vector{1}); + } + tensor_cpp_list.emplace_back(std::move(tensor_wrapper)); + } + } + + return retval; +} +#endif // #ifndef USE_ROCM + } // namespace std::vector split_quantize(const at::Tensor &tensor, @@ -535,7 +754,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool use_fused_bulk_alloc = true; for (size_t i = 0; i < quantizer_list.size(); i++) { if (!detail::IsFloat8BlockwiseQuantizers(quantizer_list[i].ptr()) && - !detail::IsMXFP8Quantizers(quantizer_list[i].ptr())) { + !detail::IsMXFP8Quantizers(quantizer_list[i].ptr()) && + !detail::IsNVFP4Quantizers(quantizer_list[i].ptr())) { use_fused_bulk_alloc = false; break; } @@ -556,6 +776,7 @@ std::vector split_quantize(const at::Tensor &tensor, // TODO(zhongbo): make a better api to make this part less hacky bool is_fp8_blockwise = detail::IsFloat8BlockwiseQuantizers(quantizer_list[0].ptr()); bool is_mxfp8 = detail::IsMXFP8Quantizers(quantizer_list[0].ptr()); + bool is_nvfp4 = detail::IsNVFP4Quantizers(quantizer_list[0].ptr()); if (is_fp8_blockwise) { // FP8 block-scaling: construct output tensors with bulk allocations std::vector blockwise_quantizers; @@ -572,6 +793,16 @@ std::vector split_quantize(const at::Tensor &tensor, } std::tie(output_py_list, output_cpp_list) = bulk_allocate_mxfp8_tensors(split_shapes, quantizer_list, mxfp8_quantizers); +#ifndef USE_ROCM + } else if (is_nvfp4) { + // NVFP4: construct output tensors with bulk allocations + std::vector nvfp4_quantizers; + for (auto &quantizer : quantizer_cpp_list) { + nvfp4_quantizers.push_back(static_cast(quantizer.get())); + } + std::tie(output_py_list, output_cpp_list) = + 
bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); +#endif } else { NVTE_CHECK(false, "Expected either FP8 block-scaling or MXFP8 quantizer"); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index b637d49c7..4a438d366 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -106,6 +106,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const bool low_precision = detail::is_low_precision(A_tensor.dtype()) || detail::is_low_precision(B_tensor.dtype()); + const bool fp8_block_scaling = A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + A_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_1D || + B_tensor.scaling_mode() == NVTE_BLOCK_SCALING_2D; // Check tensor dimensions const auto& A_shape = A_tensor.shape(); @@ -215,6 +219,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); + // Construct GEMM config + transformer_engine::MatmulConfigWrapper config; + if (grad) { + config.set_dbias_tensor(bias_tensor.data()); + config.set_with_dgelu_epilogue(gelu); + } else { + config.set_bias_tensor(bias_tensor.data()); + config.set_with_gelu_epilogue(gelu); + } + config.set_epilogue_aux_tensor(te_pre_gelu_out.data()); + config.set_use_split_accumulator(use_split_accumulator); + config.set_sm_count(num_math_sms); + #ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. 
std::vector> swizzled_scale_inverses_list; @@ -226,6 +243,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); + + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (fp8_block_scaling && transformer_engine::cuda::sm_arch() >= 100) { + // Convert tensors to mxfp8 and swizzle their scaling factors + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(A_tensor, transa))); + swizzled_scale_inverses_list.emplace_back( + std::move(convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb))); + // Use TN GEMM to avoid having to transpose data. + transa = true; + transb = false; + } #endif if (comm_overlap) { @@ -286,10 +316,9 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), - bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, - te_workspace.data(), alpha, *beta, use_split_accumulator, - num_math_sms, main_stream); + nvte_cublas_gemm_v2(transa, transb, &alpha, A_tensor.data(), B_tensor.data(), &beta.value(), + out_tensor.data(), out_tensor.data(), te_workspace.data(), config, + main_stream); }); } } else { @@ -377,15 +406,6 @@ std::optional> te_general_grouped_gemm( std::vector bias, DType bias_type, bool single_output, std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { - std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, - te_pre_gelu_out_vector, te_workspace_vector; - std::vector te_A_wrappers, te_B_wrappers, wrappers; - std::vector D_vectors; - - auto none = 
py::none(); - - std::vector single_output_begins; - std::vector single_output_ends; if (single_output && D == std::nullopt) { NVTE_ERROR("not implemented, D should be allocated for single output case."); } @@ -395,6 +415,10 @@ std::optional> te_general_grouped_gemm( output_data_ptr = (*D)[0].data_ptr(); } + const auto none = py::none(); + std::vector te_A_wrappers, te_B_wrappers, te_D_wrappers, te_bias_wrappers, + te_pre_gelu_out_wrappers; + std::vector D_vectors; for (size_t i = 0; i < A.size(); i++) { auto te_A = makeTransformerEngineTensor(A[i], none); auto te_B = makeTransformerEngineTensor(B[i], none); @@ -460,31 +484,74 @@ std::optional> te_general_grouped_gemm( te_pre_gelu_out = makeTransformerEngineTensor(get_data_ptr(pre_gelu_out[i]), gelu_shape, gelu_type); - te_A_vector.emplace_back(te_A.data()); - te_B_vector.emplace_back(te_B.data()); - te_D_vector.emplace_back(te_D.data()); - te_bias_vector.emplace_back(te_bias.data()); - te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - te_A_wrappers.emplace_back(std::move(te_A)); te_B_wrappers.emplace_back(std::move(te_B)); - wrappers.emplace_back(std::move(te_D)); - wrappers.emplace_back(std::move(te_bias)); - wrappers.emplace_back(std::move(te_pre_gelu_out)); + te_D_wrappers.emplace_back(std::move(te_D)); + te_bias_wrappers.emplace_back(std::move(te_bias)); + te_pre_gelu_out_wrappers.emplace_back(std::move(te_pre_gelu_out)); } #ifndef USE_ROCM + // Keep the swizzled scaling factor tensors alive during the GEMM. + std::vector> swizzled_scale_inverses_list; + // Optionally swizzle the scaling factors - // Keep the swizzled scaling factor tensors alive during the GEMMs. 
- auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); - auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa)); + swizzled_scale_inverses_list.emplace_back( + multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb)); + + // Emulate the FP8 block scaling recipe with MXFP8 on Blackwell and newer + // as it is not natively supported by cublasLt + if (transformer_engine::cuda::sm_arch() >= 100) { + // Check if is using FP8 block scaling + bool exists_tensor_using_fp8_block_scaling = false; + bool exists_tensor_not_using_fp8_block_scaling = false; + for (const auto& tensor_wrappers : {&te_A_wrappers, &te_B_wrappers}) { + for (const TensorWrapper& tensor : *tensor_wrappers) { + const NVTEScalingMode scaling_mode = tensor.scaling_mode(); + if (scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D) + exists_tensor_using_fp8_block_scaling = true; + else + exists_tensor_not_using_fp8_block_scaling = true; + } + } + if (exists_tensor_using_fp8_block_scaling) { + NVTE_CHECK(!exists_tensor_not_using_fp8_block_scaling, + "Either all tensors or no tensor must be FP8 block scaling tensors"); + // Convert tensors to mxfp8 and swizzle their scaling factors + for (TensorWrapper& A_tensor : te_A_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(A_tensor, transa)); + } + for (TensorWrapper& B_tensor : te_B_wrappers) { + swizzled_scale_inverses_list.emplace_back( + convert_block_scaling_to_mxfp8_tensor(B_tensor, !transb)); + } + // Use TN GEMM to avoid having to transpose data. 
+ transa = true; + transb = false; + } + } #endif + std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, + te_pre_gelu_out_vector; + for (size_t i = 0; i < te_A_wrappers.size(); i++) { + te_A_vector.emplace_back(te_A_wrappers[i].data()); + te_B_vector.emplace_back(te_B_wrappers[i].data()); + te_D_vector.emplace_back(te_D_wrappers[i].data()); + te_bias_vector.emplace_back(te_bias_wrappers[i].data()); + te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out_wrappers[i].data()); + } + + std::vector te_workspace_vector; + std::vector te_workspace_wrappers; for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); - wrappers.emplace_back(std::move(wsp)); + te_workspace_wrappers.emplace_back(std::move(wsp)); } // For now, we only have multi-stream cublas backend. diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 728d39cbd..839bb694a 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -68,67 +68,108 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); - TensorWrapper bias_cu; + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); + TensorWrapper bias_nvte; if (bias.has_value()) { - bias_cu = makeTransformerEngineTensor(*bias); + bias_nvte = makeTransformerEngineTensor(*bias); } // Tensor dimensions - const size_t N = static_cast(input_cu.size(0)); - const size_t H = static_cast(input_cu.size(1)); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - at::Tensor mu = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - at::Tensor rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - TensorWrapper mu_cu = makeTransformerEngineTensor(mu); - TensorWrapper rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor mu_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper mu_nvte = makeTransformerEngineTensor(mu_py); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = 
makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; +#ifdef USE_ROCM } - TensorWrapper unquantized_out_cu; +#else + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + 
impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; + } + } + #endif + + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#endif + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -140,24 +181,33 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_layernorm_fwd(input_cu.data(), weight_cu.data(), bias_cu.data(), eps, kernel_out_cu.data(), - mu_cu.data(), rsigma_cu.data(), workspace.data(), + nvte_layernorm_fwd(input_nvte.data(), weight_nvte.data(), bias_nvte.data(), eps, + kernel_out_nvte->data(), mu_nvte.data(), rsigma_nvte.data(), + workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + 
nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#endif + default: { } } - return {out, py::cast(mu), py::cast(rsigma)}; + return {out, py::cast(mu_py), py::cast(rsigma_py)}; } std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, @@ -256,61 +306,101 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Input and param tensors auto none = py::none(); - const TensorWrapper &input_cu = makeTransformerEngineTensor(input, none); - const TensorWrapper &weight_cu = makeTransformerEngineTensor(weight, none); + const TensorWrapper &input_nvte = makeTransformerEngineTensor(input, none); + const TensorWrapper &weight_nvte = makeTransformerEngineTensor(weight, none); // Tensor dimensions - const size_t N = static_cast(input_cu.shape().data[0]); - const size_t H = static_cast(input_cu.shape().data[1]); - const std::vector size = {N, H}; + const auto shape = nvte_shape_to_vector(input_nvte.shape()); + const auto outer_size = product(shape) / shape.back(); + const auto inner_size = shape.back(); // Tensors to save for backward pass - auto rsigma = at::empty({static_cast(N)}, at::CUDA(at::kFloat)); - auto rsigma_cu = makeTransformerEngineTensor(rsigma); + at::Tensor rsigma_py = at::empty({static_cast(outer_size)}, at::CUDA(at::kFloat)); + TensorWrapper rsigma_nvte = makeTransformerEngineTensor(rsigma_py); // Output tensor - std::unique_ptr my_quantizer = convert_quantizer(quantizer); - TensorWrapper out_cu; + auto quantizer_cpp = convert_quantizer(quantizer); + TensorWrapper out_nvte; if (out.is_none()) { - std::tie(out_cu, out) = my_quantizer->create_tensor(size, out_dtype); + std::tie(out_nvte, out) = quantizer_cpp->create_tensor(shape, out_dtype); } else { - out_cu = makeTransformerEngineTensor(out, quantizer); + out_nvte = makeTransformerEngineTensor(out, quantizer); } - // Determine whether to avoid fused kernel - bool force_unfused_kernel = true; - if (quantizer.is_none()) { - // No need for separate 
quantization step if output is unquantized - force_unfused_kernel = false; - } else if (IsFloat8Quantizers(quantizer.ptr())) { - // Always used fused kernel for FP8 delayed scaling - force_unfused_kernel = false; + // Choose implementation + enum class Impl { + // Compute norm in high precision, then quantize + UNFUSED, + // Compute norm directly + FULLY_FUSED, + // Compute norm and amax in high precision, then quantize to FP8 + FUSED_NORM_AMAX_FP8, + // Compute norm and amax in high precision, then quantize to NVFP4 + FUSED_NORM_AMAX_NVFP4 + }; + Impl impl = Impl::UNFUSED; + if (quantizer.is_none() || IsFloat8Quantizers(quantizer.ptr())) { + impl = Impl::FULLY_FUSED; } else if (IsMXFP8Quantizers(quantizer.ptr())) { - if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - // cuDNN MXFP8 kernel requires full tile - force_unfused_kernel = N % 128 != 0 || H % 128 != 0; + if (transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN") && outer_size % 128 == 0 && + inner_size % 128 == 0) { + // cuDNN MXFP8 kernel requires full 128x128 tiles + impl = Impl::FULLY_FUSED; } + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + auto fp8_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer"); + impl = Impl::FUSED_NORM_AMAX_FP8; +#ifdef USE_ROCM } - TensorWrapper unquantized_out_cu; +#else + } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { + auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); + if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { + // Post-RHT amax is handled within NVFP4 quantizer + impl = Impl::UNFUSED; + } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { + // TE kernel supports amax output + impl = Impl::FUSED_NORM_AMAX_NVFP4; + } + } +#endif 
+ + // Construct unquantized output tensor if needed + TensorWrapper unquantized_out_nvte; py::object unquantized_out; - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - std::tie(unquantized_out_cu, unquantized_out) = - my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); - } else { + TensorWrapper *kernel_out_nvte = &out_nvte; + switch (impl) { + case Impl::UNFUSED: { NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + std::tie(unquantized_out_nvte, unquantized_out) = q.create_tensor(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + fp8_quantizer_cpp->create_unquantized_tensor_with_amax(shape, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + std::tie(unquantized_out_nvte, unquantized_out) = + nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, out_dtype); + kernel_out_nvte = &unquantized_out_nvte; + } break; +#endif + default: { } } - TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; // Query workspace size TensorWrapper workspace; NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); @@ -322,24 +412,32 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Launch kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_rmsnorm_fwd(input_cu.data(), weight_cu.data(), eps, kernel_out_cu.data(), rsigma_cu.data(), - workspace.data(), + nvte_rmsnorm_fwd(input_nvte.data(), weight_nvte.data(), eps, kernel_out_nvte->data(), + rsigma_nvte.data(), workspace.data(), at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, zero_centered_gamma, at::cuda::getCurrentCUDAStream()); }); - // Quantize output if using unfused kernel - if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && - !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { - auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); - my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); - } else { - my_quantizer->quantize(unquantized_out_cu, out_cu); + // Quantize output if needed + switch (impl) { + case Impl::UNFUSED: { + quantizer_cpp->quantize(unquantized_out_nvte, out_nvte); + } break; + case Impl::FUSED_NORM_AMAX_FP8: { + auto fp8_quantizer_cpp = static_cast(quantizer_cpp.get()); + fp8_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#ifndef USE_ROCM + case Impl::FUSED_NORM_AMAX_NVFP4: { + auto nvfp4_quantizer_cpp = static_cast(quantizer_cpp.get()); + nvfp4_quantizer_cpp->quantize_with_amax(unquantized_out_nvte, out_nvte); + } break; +#endif + default: { } } - return {out, py::none(), py::cast(rsigma)}; + return {out, 
py::none(), py::cast(rsigma_py)}; } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 55b1d179e..268708d5e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -25,15 +25,18 @@ namespace transformer_engine::pytorch { PyTypeObject *Float8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *Float8TensorBasePythonClass = nullptr; +PyTypeObject *Float8TensorStoragePythonClass = nullptr; PyTypeObject *Float8QuantizerClass = nullptr; PyTypeObject *Float8CurrentScalingQuantizerClass = nullptr; PyTypeObject *MXFP8TensorPythonClass = nullptr; /// TODO Remove -PyTypeObject *MXFP8TensorBasePythonClass = nullptr; +PyTypeObject *MXFP8TensorStoragePythonClass = nullptr; PyTypeObject *MXFP8QuantizerClass = nullptr; PyTypeObject *Float8BlockwiseQTensorPythonClass = nullptr; -PyTypeObject *Float8BlockwiseQTensorBasePythonClass = nullptr; +PyTypeObject *Float8BlockwiseQTensorStoragePythonClass = nullptr; PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; +PyTypeObject *NVFP4TensorPythonClass = nullptr; +PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; +PyTypeObject *NVFP4QuantizerClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -45,9 +48,9 @@ void init_float8_extension() { Float8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Tensor")); auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.float8_tensor_base"); - Float8TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.float8_tensor_storage"); + Float8TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8TensorStorage")); 
NVTE_CHECK(Float8TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch Float8 extension."); } @@ -60,38 +63,54 @@ void init_mxfp8_extension() { MXFP8TensorPythonClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Tensor")); auto fp8_base_module = - py::module_::import("transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base"); - MXFP8TensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorBase")); + py::module_::import("transformer_engine.pytorch.tensor.storage.mxfp8_tensor_storage"); + MXFP8TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "MXFP8TensorStorage")); NVTE_CHECK(MXFP8TensorPythonClass != nullptr, "Internal error: could not initialize pyTorch MXFP8 extension."); } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorBasePythonClass) return; + if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( - "transformer_engine.pytorch.tensor._internal.float8_blockwise_tensor_base"); + "transformer_engine.pytorch.tensor.storage.float8_blockwise_tensor_storage"); Float8BlockwiseQuantizerClass = reinterpret_cast( PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockQuantizer")); - Float8BlockwiseQTensorBasePythonClass = reinterpret_cast( - PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorBase")); + Float8BlockwiseQTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(fp8_base_module.ptr(), "Float8BlockwiseQTensorStorage")); Float8BlockwiseQTensorPythonClass = reinterpret_cast( PyObject_GetAttrString(fp8_module.ptr(), "Float8BlockwiseQTensor")); NVTE_CHECK(Float8BlockwiseQuantizerClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); - NVTE_CHECK(Float8BlockwiseQTensorBasePythonClass != nullptr, + 
NVTE_CHECK(Float8BlockwiseQTensorStoragePythonClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); NVTE_CHECK(Float8BlockwiseQTensorPythonClass != nullptr, "Internal error: could not initialize pyTorch float8blockwise extension."); } +void init_nvfp4_extensions() { + if (NVFP4TensorPythonClass) return; + auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); + NVFP4QuantizerClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); + NVFP4TensorPythonClass = + reinterpret_cast(PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Tensor")); + auto nvfp4_base_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage"); + NVFP4TensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(nvfp4_base_module.ptr(), "NVFP4TensorStorage")); + NVTE_CHECK(NVFP4TensorPythonClass != nullptr, + "Internal error: could not initialize pyTorch NVFP4 extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); + init_nvfp4_extensions(); } } // namespace transformer_engine::pytorch @@ -138,6 +157,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer")); m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); + m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, + "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), + py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); @@ -161,6 +183,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("fwd_input"), py::arg("quantizer")); m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), 
py::arg("quantizer")); + m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu, + "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"), + py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 10a889b56..577a938f2 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -22,12 +22,13 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor"); TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element"); + auto* amax_ptr = amax.data_ptr(); TensorWrapper fake_te_output( nullptr, te_input.shape(), DType::kFloat8E4M3, // It doesn't matter because we only compute amax. - amax.data_ptr()); + amax_ptr); -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM at::Tensor ws = allocate_amax_workspace(te_input); TensorWrapper tw = makeTransformerEngineTensor(ws); nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 9fd1ae4de..a84641364 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -31,22 +33,21 @@ namespace transformer_engine::pytorch { } while (false); extern PyTypeObject *Float8TensorPythonClass; -extern PyTypeObject *Float8TensorBasePythonClass; +extern PyTypeObject *Float8TensorStoragePythonClass; extern PyTypeObject *Float8QuantizerClass; extern PyTypeObject *Float8CurrentScalingQuantizerClass; extern PyTypeObject *MXFP8TensorPythonClass; -extern PyTypeObject *MXFP8TensorBasePythonClass; +extern PyTypeObject *MXFP8TensorStoragePythonClass; extern PyTypeObject *MXFP8QuantizerClass; extern PyTypeObject *Float8BlockwiseQTensorPythonClass; -extern PyTypeObject *Float8BlockwiseQTensorBasePythonClass; +extern PyTypeObject *Float8BlockwiseQTensorStoragePythonClass; extern PyTypeObject *Float8BlockwiseQuantizerClass; +extern PyTypeObject *NVFP4TensorPythonClass; +extern PyTypeObject *NVFP4TensorStoragePythonClass; +extern PyTypeObject *NVFP4QuantizerClass; void init_extension(); -void init_float8_extension(); - -void init_mxfp8_extension(); - namespace detail { inline bool IsFloat8Quantizers(PyObject *obj) { return Py_TYPE(obj) == Float8QuantizerClass; } @@ -56,22 +57,28 @@ inline bool IsFloat8CurrentScalingQuantizers(PyObject *obj) { } inline bool IsFloat8Tensor(PyObject *obj) { - return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorBasePythonClass; + return Py_TYPE(obj) == Float8TensorPythonClass || Py_TYPE(obj) == Float8TensorStoragePythonClass; } inline bool IsMXFP8Quantizers(PyObject *obj) { return Py_TYPE(obj) == MXFP8QuantizerClass; } inline bool IsMXFP8Tensor(PyObject *obj) { - return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorBasePythonClass; + return Py_TYPE(obj) == MXFP8TensorPythonClass || Py_TYPE(obj) == MXFP8TensorStoragePythonClass; } inline bool IsFloat8BlockwiseQuantizers(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQuantizerClass; } +inline bool IsNVFP4Quantizers(PyObject *obj) { return Py_TYPE(obj) == NVFP4QuantizerClass; } + inline bool 
IsFloat8BlockwiseQTensor(PyObject *obj) { return Py_TYPE(obj) == Float8BlockwiseQTensorPythonClass || - Py_TYPE(obj) == Float8BlockwiseQTensorBasePythonClass; + Py_TYPE(obj) == Float8BlockwiseQTensorStoragePythonClass; +} + +inline bool IsNVFP4Tensor(PyObject *obj) { + return Py_TYPE(obj) == NVFP4TensorPythonClass || Py_TYPE(obj) == NVFP4TensorStoragePythonClass; } TensorWrapper NVTETensorFromFloat8Tensor(py::handle tensor, Quantizer *quantizer); @@ -88,6 +95,8 @@ std::unique_ptr CreateMXFP8Params(const py::handle params); TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer *quantization_params); +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } @@ -100,8 +109,13 @@ constexpr std::array custom_types_converters = { std::make_tuple(IsMXFP8Tensor, IsMXFP8Quantizers, NVTETensorFromMXFP8Tensor, CreateQuantizer), std::make_tuple(IsFloat8BlockwiseQTensor, IsFloat8BlockwiseQuantizers, - NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer)}; - + NVTETensorFromFloat8BlockwiseQTensor, CreateQuantizer), +#ifdef __HIP_PLATFORM_AMD__ +}; +#else + std::make_tuple(IsNVFP4Tensor, IsNVFP4Quantizers, NVTETensorFromNVFP4Tensor, + CreateQuantizer)}; +#endif } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 37c13362c..90ed2a99f 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -33,8 +33,20 @@ std::vector make_transpose_shape(const std::vector& shape) { return ret; } +/*! 
@brief Convert shape for FP4 data by dividing the last dimension by 2 */ +template +std::vector convert_shape_for_fp4(const std::vector& shape) { + std::vector ret; + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + ret.push_back(shape.back() / 2); + return ret; +} + } // namespace +constexpr size_t NVFP4_BLOCK_SIZE = 16; constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -142,7 +154,7 @@ std::pair Float8Quantizer::create_tensor( // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); @@ -347,7 +359,7 @@ std::pair Float8CurrentScalingQuantizer::create_tenso py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? 
py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); + py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, "quantizer"_a = this->quantizer); @@ -378,10 +390,15 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } -std::pair Float8CurrentScalingQuantizer::create_hp_tensor_with_amax( - const std::vector& shape, DType dtype) { +std::pair +Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, + DType dtype, + std::optional data) { amax.zero_(); - auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + auto out = data.has_value() ? NoneQuantizer(py::none()).create_tensor(shape, dtype, data.value()) + : NoneQuantizer(py::none()).create_tensor(shape, dtype); + TensorWrapper out_cpp = std::move(out.first); + py::object out_py = std::move(out.second); out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); return {std::move(out_cpp), std::move(out_py)}; @@ -499,7 +516,7 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { -#ifdef __HIP_PLATFORM_AMD__ +#ifdef USE_ROCM at::Tensor ws = allocate_amax_workspace(input); TensorWrapper tw = makeTransformerEngineTensor(ws); NVTE_SCOPED_GIL_RELEASE({ @@ -623,7 +640,7 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorBasePythonClass)); + reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); ret = Float8BlockwiseQTensorClass( "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = 
scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, @@ -909,7 +926,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -943,7 +960,7 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); + py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, "rowwise_scale_inv"_a = rowwise_scale_inv_py, @@ -1105,7 +1122,7 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s auto last_dim = shape.back(); NVTE_CHECK(last_dim % MXFP8_BLOCK_SIZE == 0 && (numel / last_dim) % MXFP8_BLOCK_SIZE == 0, - "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + "MXFP8 requires tensor dims that are divisible by ", MXFP8_BLOCK_SIZE, " (got shape=", shape, ")"); std::vector scale_shape; @@ -1126,4 +1143,583 @@ std::vector MXFP8Quantizer::get_scale_shape(const std::vector& s return scale_shape; } +#ifndef USE_ROCM +NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantizer) { + this->dtype = quantizer.attr("dtype").cast(); + this->with_rht = quantizer.attr("with_rht").cast(); + this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); + this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); + this->stochastic_rounding = 
quantizer.attr("stochastic_rounding").cast(); + + // Get amax reduction group if needed for NVFP4 AG + const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); + c10::intrusive_ptr amax_reduction_group; + if (with_amax_reduction) { + auto group = quantizer.attr("_canonicalized_amax_reduction_group")(); + NVTE_CHECK(!group.is_none(), "NVFP4Quantizer could not canonicalize amax reduction group"); + amax_reduction_group = group.cast>(); + } + this->with_amax_reduction = with_amax_reduction; + this->amax_reduction_group = amax_reduction_group; + + this->rht_matrix_random_sign_mask_t = quantizer.attr("rht_matrix_random_sign_mask_t").cast(); + this->rht_matrix = quantizer.attr("rht_matrix").cast(); +} + +void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const { + // set dtype for rowwise and columnwise data in tensor wrapper + auto rowwise_data = tensor->get_rowwise_data(); + rowwise_data.dtype = static_cast(this->dtype); + + auto columnwise_data = tensor->get_columnwise_data(); + columnwise_data.dtype = static_cast(this->dtype); + + tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), + rowwise_data.shape); + tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), + columnwise_data.shape); +} + +std::pair NVFP4Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { + using namespace pybind11::literals; + + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, "First dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; + const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed + amax_rowwise = at::empty({1}, bit32_tensor_opts); + } + if (columnwise_usage) { + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data_tensor = + at::empty(convert_shape_for_fp4(transpose_shape_int64), bit8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + // hadamard amax kernel 
will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed + amax_columnwise = at::empty({1}, bit32_tensor_opts); + } + + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + auto amax_rowwise_py = py_cast(amax_rowwise, rowwise_usage); + auto amax_columnwise_py = py_cast(amax_columnwise, columnwise_usage); + + // Construct Python NVFP4 tensor + py::object out_py; + if (internal) { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); + out_py = NVFP4TensorClass( + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } else { + py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); + out_py = NVFP4TensorClass( + "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, + "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, + "quantizer"_a = this->quantizer); + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); + 
out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + rowwise_scale_inv_shape); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, + columnwise_scale_inv_shape); + out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( + TensorWrapper& quantized_tensor, DType dtype) { + // Construct tensor + auto shape = convertShape(quantized_tensor.shape()); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + + // Register amax pointer from quantized tensor + void* amax_ptr = quantized_tensor.amax(); + if (amax_ptr == nullptr) { + amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + } + NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); + out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + + // Zero out amax + NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair NVFP4Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + 
auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + auto amax_rowwise = get_tensor("_amax_rowwise"); + auto amax_columnwise = get_tensor("_amax_columnwise"); + NVTE_CHECK(rowwise_data || columnwise_data, "NVFP4Tensor has no data."); + + // Tensor dimensions, shape means original shape + std::vector shape; + if (columnwise_data) { + shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); + if (rowwise_data) { + auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked columnwise_data_tensor == true + shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); + } + + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; + } + } + const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(convert_shape_for_fp4(shape_int64), opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + if (!amax_rowwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed + amax_rowwise = at::empty({1}, opts); + tensor.attr("_amax_rowwise") = *amax_rowwise; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + if (amax_rowwise) { + amax_rowwise.reset(); + tensor.attr("_amax_rowwise") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_int64_2d = {static_cast(flat_first_dim), + static_cast(flat_last_dim)}; + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto transpose_shape_int64 = make_transpose_shape(shape_int64_2d); + columnwise_data = at::empty(convert_shape_for_fp4(transpose_shape_int64), opts); + 
tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + if (!amax_columnwise) { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + // hadamard amax kernel will zero out pointer with ZeroAmaxKernel + // nvte_compute_amax_with_config will zero out the pointer if needed + amax_columnwise = at::empty({1}, opts); + tensor.attr("_amax_columnwise") = *amax_columnwise; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + if (amax_columnwise) { + amax_columnwise.reset(); + tensor.attr("_amax_columnwise") = py::none(); + } + } + + // Construct C++ tensor + TensorWrapper out_cpp(NVTE_NVFP4_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (columnwise_usage) { + // enforce 2D shape to avoid [S, B, H] shape and B and be 1 + // and the transposed shape is [H, S, B], so divide last dim by 2 gives zero + std::vector shape_2d = {flat_first_dim, flat_last_dim}; + auto col_data_shape_fp4 = make_transpose_shape(shape_2d); + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), DType::kFloat4E2M1, + col_data_shape_fp4); + 
out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(amax_columnwise->data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, + bool compute_amax) { + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + auto stream = at::cuda::getCurrentCUDAStream(); + + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); + quant_config.set_stochastic_rounding(this->stochastic_rounding); + + // We only need RHT for columnwise usage. + // flat first dim and last dim for multi dimensional input + size_t rows = 1; + for (size_t i = 0; i < input.ndim() - 1; ++i) { + rows *= input.size(i); + } + size_t cols = input.size(input.ndim() - 1); + + TensorWrapper te_rng_state; + if (this->stochastic_rounding) { + const size_t rng_elts_per_thread = 1024; // Wild guess, probably can be tightened + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args = init_philox_state(gen, rng_elts_per_thread); + auto opts = at::TensorOptions().dtype(torch::kInt64).device(torch::kCUDA); + auto rng_state = torch::empty({2}, opts); + philox_unpack(philox_args, static_cast(rng_state.data_ptr())); + te_rng_state = makeTransformerEngineTensor(rng_state); + quant_config.set_rng_state(te_rng_state.data()); + } + + // Restriction for the RHT cast fusion kernel. + bool eligible_for_rht_cast_fusion = + input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; + + // Compute amax. 
+ if (this->with_rht) { + if (input.dtype() != DType::kBFloat16) { + NVTE_CHECK(false, "RHT is only supported for bfloat16 input"); + } + if (this->with_post_rht_amax) { + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for RHT(input.t) + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_amax(input.data(), out.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + } else { + // raise error since it's not supported yet + NVTE_CHECK(false, "Pre-RHT amax is not supported yet"); + } + } else { // Without RHT + if (compute_amax) { + // Amax pointers + auto rowwise_amax_ptr = out.get_amax().data_ptr; + auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + void* amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + + // Compute amax of input tensor + out.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); + out.set_amax(rowwise_amax_ptr, DType::kFloat32, std::vector{1}); + + // Make sure row-wise and column-wise amaxes match + if (rowwise_amax_ptr != amax_ptr && rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(rowwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + if (columnwise_amax_ptr != amax_ptr && columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(columnwise_amax_ptr, amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } + + // amax reduction + if (this->with_amax_reduction) { + std::vector amax_tensors; + // push amax tensors inside if they need to be reduced + auto make_amax_tensor = [](void* data_ptr) { + return at::from_blob( + data_ptr, std::vector{1}, + [](void*) {}, // deleter doing nothing since it doesn't own the data + at::device(at::kCUDA).dtype(torch::kFloat32)); + }; + if (rowwise_usage) { + 
amax_tensors.push_back(make_amax_tensor(out.get_amax().data_ptr)); + } + if (columnwise_usage) { + amax_tensors.push_back(make_amax_tensor(out.get_columnwise_amax().data_ptr)); + } + c10d::AllreduceCoalescedOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + NVTE_SCOPED_GIL_RELEASE( + { this->amax_reduction_group->allreduce_coalesced(amax_tensors, opts)->wait(); }); + } + + if (this->with_rht) { + if (rowwise_usage) { + // For rowwise usage, we need to quantize the input directly, but we need to avoid quantizing columnwise + TensorWrapper out_identity(out.scaling_mode()); + auto out_identity_data = out.get_rowwise_data(); + auto out_identity_scale_inv = out.get_rowwise_scale_inv(); + auto out_identity_amax = out.get_amax(); + out_identity.set_rowwise_data(out_identity_data.data_ptr, + static_cast(out_identity_data.dtype), + out_identity_data.shape); + out_identity.set_rowwise_scale_inv(out_identity_scale_inv.data_ptr, + static_cast(out_identity_scale_inv.dtype), + out_identity_scale_inv.shape); + out_identity.set_amax(out_identity_amax.data_ptr, static_cast(out_identity_amax.dtype), + out_identity_amax.shape); + + NVTE_SCOPED_GIL_RELEASE( + { nvte_quantize_v2(input.data(), out_identity.data(), quant_config, stream); }); + } + + if (columnwise_usage) { + // Get the output columnwise data, scale_inv, and amax + auto out_columnwise_data = out.get_columnwise_data(); + auto out_columnwise_scale_inv = out.get_columnwise_scale_inv(); + // NOTE: should already be populated. + auto out_columnwise_amax = out.get_columnwise_amax(); + + // Create a wrapper for the columnwise output, as the rowwise output. + // The reason is due to the input `rht_output_t` is already in the transposed layout. + // Thus, we only need a rowwise quantization to generate the columnwise output. 
+ TensorWrapper out_transpose(out.scaling_mode()); + // Note: since we are faking columnwise tensor into rowwise, the flat first dim check will fail + // need to convert the shape to 2D here + auto colwise_data_shape = out_columnwise_data.shape; + std::vector colwise_data_shape_2d; + // shape could be [512, 32, 64], that's actually 512, 32, 128 because 2 FP4 take 1 byte + // the 2D shape should be [512, 32*128], but columnwise data shape expect last dim to be halved again + // so the multiple 2 get cancelled out + colwise_data_shape_2d.push_back(colwise_data_shape.data[0]); + size_t last_dim = 1; + for (size_t i = 1; i < colwise_data_shape.ndim; ++i) { + last_dim *= colwise_data_shape.data[i]; + } + colwise_data_shape_2d.push_back(last_dim); + + out_transpose.set_rowwise_data(out_columnwise_data.data_ptr, + static_cast(out_columnwise_data.dtype), + colwise_data_shape_2d); + out_transpose.set_rowwise_scale_inv(out_columnwise_scale_inv.data_ptr, + static_cast(out_columnwise_scale_inv.dtype), + out_columnwise_scale_inv.shape); + out_transpose.set_amax(out_columnwise_amax.data_ptr, + static_cast(out_columnwise_amax.dtype), + out_columnwise_amax.shape); + + if (!eligible_for_rht_cast_fusion) { + // Invoking fallback RHT kernel. + + // If using RHT, then amax will be computed in the RHT step + // If not using RHT, then amax will be computed based on input x + at::Tensor rht_output_t; // The RHT(x_t) output, in columnwise layout + // This wrapper is going to be passed as input to the quantization kernel. + TensorWrapper rht_output_t_cpp; // Wrapper to contain the RHT(x) and RHT(x_t) outputs + rht_output_t = + allocateTorchTensor(static_cast(cols), static_cast(rows), input.dtype()); + // NOTE (frsun): This is non-intuitive, we are writing the + // result of transposed RHT to the output of rowwise. 
+ rht_output_t_cpp.set_rowwise_data(rht_output_t.data_ptr(), input.dtype(), + std::vector{cols, rows}); + + NVTE_SCOPED_GIL_RELEASE({ + // Perform the RHT(input.t), and write to rht_output_cpp.columnwise. + nvte_hadamard_transform(input.data(), rht_output_t_cpp.data(), 0, + this->rht_matrix_random_sign_mask_t, stream); + }); + + // Quantize kernel will treat everything as rowwise input/output, which is + // intended. + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(rht_output_t_cpp.data(), out_transpose.data(), quant_config, stream); + }); + } else { + // RHT cast fusion kernel. + NVTE_CHECK(this->rht_matrix.defined() && this->rht_matrix.numel() > 0, + "RHT matrix is not set"); + auto rht_matrix_nvte = makeTransformerEngineTensor(this->rht_matrix); + NVTE_SCOPED_GIL_RELEASE({ + nvte_hadamard_transform_cast_fusion_columnwise( + input.data(), out_transpose.data(), rht_matrix_nvte.data(), quant_config, stream); + }); + } + } + } else { + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); + } +} + +void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true); +} + +void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + // Update output tensor amaxes with input tensor amax + auto input_amax_ptr = input.amax(); + auto output_rowwise_amax_ptr = out.get_amax().data_ptr; + auto output_columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; + NVTE_CHECK(input_amax_ptr != nullptr || + (output_rowwise_amax_ptr == nullptr && output_columnwise_amax_ptr == nullptr), + "Input tensor does not have pre-computed amax"); + if (input_amax_ptr != output_rowwise_amax_ptr && input_amax_ptr != nullptr && + output_rowwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_rowwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + if 
(input_amax_ptr != output_columnwise_amax_ptr && input_amax_ptr != nullptr && + output_columnwise_amax_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(output_columnwise_amax_ptr, input_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, at::cuda::getCurrentCUDAStream())); + } + input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + + // Perform quantization + this->quantize_impl(input, out, std::nullopt, false); +} + +std::vector NVFP4Quantizer::get_scale_shape(const std::vector& shape, + bool columnwise) const { + size_t numel = 1; + for (auto s : shape) { + numel *= s; + } + + auto last_dim = shape.back(); + auto flat_first_dim = numel / last_dim; + + NVTE_CHECK(last_dim % NVFP4_BLOCK_SIZE == 0, "Last dim for NVFP4 must be divisible by ", + NVFP4_BLOCK_SIZE, " (got dim=", last_dim, ")"); + NVTE_CHECK(flat_first_dim % NVFP4_BLOCK_SIZE == 0, + "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, + " (got shape=", shape, ")"); + + std::vector scale_shape; + + bool rowwise_usage = !columnwise; + + if (rowwise_usage) { + // rowwise scaling factor shape + size_t sinv0 = roundup(flat_first_dim, 128); + size_t sinv1 = roundup(last_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } else { + // columnwise scaling factor shape + size_t sinv0 = roundup(last_dim, 128); + size_t sinv1 = roundup(flat_first_dim / NVFP4_BLOCK_SIZE, 4); + scale_shape = {sinv0, sinv1}; + } + return scale_shape; +} +#endif + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index cb2121a45..368e9dcdf 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -116,6 +116,46 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, Quantizer return ret; } +TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) { + const DType dtype = 
tensor.attr("_fp4_dtype").cast(); + + auto ret = TensorWrapper(NVTE_NVFP4_1D_SCALING); + + bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); + bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); + + NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); + + // Row-scaled data + if (rowwise_usage) { + const auto &data = tensor.attr("_rowwise_data").cast(); + const auto &scale_inv = tensor.attr("_rowwise_scale_inv").cast(); + const auto &amax_rowwise = tensor.attr("_amax_rowwise").cast(); + ret.set_rowwise_data(data.data_ptr(), dtype, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, getTensorShape(scale_inv)); + ret.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); + } + + // Column-scaled data + if (columnwise_usage) { + const auto &data = tensor.attr("_columnwise_data").cast(); + const auto &scale_inv = tensor.attr("_columnwise_scale_inv").cast(); + const auto &amax_columnwise = tensor.attr("_amax_columnwise").cast(); + ret.set_columnwise_data(data.data_ptr(), DType::kFloat4E2M1, + convert_shape_back_from_fp4(getTensorShape(data), false)); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat8E4M3, + getTensorShape(scale_inv)); + ret.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, + getTensorShape(amax_columnwise)); + } + + // Quantizer state + quantizer->set_quantization_params(&ret); + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 44b636930..3948c6403 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -11,6 +11,7 @@ #include "util.h" #include "common.h" +#include "common/common.h" std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper& input, bool rowwise) 
{ @@ -18,22 +19,31 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap if (input.scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (input.scaling_mode() != NVTE_MXFP8_1D_SCALING && + input.scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } - NVTE_CHECK(input.element_size() == 1, "8-bit input required for swizzling scaling factors."); + NVTE_CHECK(input.element_size_bits() == 4 || input.element_size_bits() == 8, + "4-bit or 8-bit input required for swizzling scaling factors."); + + const auto nvfp4 = input.scaling_mode() == NVTE_NVFP4_1D_SCALING; NVTEBasicTensor scale_inv; + NVTEShape nvte_input_shape; if (rowwise) { + nvte_input_shape = input.shape(); scale_inv = input.get_rowwise_scale_inv(); } else { + nvte_input_shape = input.get_columnwise_data().shape; scale_inv = input.get_columnwise_scale_inv(); } - auto input_shape = nvte_shape_to_vector(input.shape()); + auto input_shape = nvte_shape_to_vector(nvte_input_shape); auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + NVTE_CHECK(input_shape.size() >= 2, "Wrong ndims for swizzle input shape."); + // Allocate memory for swizzled output. auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); std::vector scale_inv_shape_int; @@ -44,37 +54,33 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap void* scale_inv_dptr = scale_inv.data_ptr; void* swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); - // Reconstruct input only to avoid swizzling both directions if not needed. - // Use any 8 bit type, it's irrelevant. 
- transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper input_cu(input.scaling_mode()); + transformer_engine::TensorWrapper output_cu(input.scaling_mode()); + + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + if (rowwise) { - input_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_rowwise_data(input.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_rowwise_data(input.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); - output_cu.set_columnwise_data(input.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shape); + input_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shape); + output_cu.set_columnwise_data(input.columnwise_dptr(), input_dtype, input_shape); 
+ output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } // Launch kernel nvte_swizzle_scaling_factors(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } else { - input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shape); + input.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shape); } return swizzled_scale_inv; @@ -96,10 +102,14 @@ std::optional multi_tensor_swizzle_scaling_factors( if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { NVTE_ERROR("Invalid scaling mode for swizzle."); - } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING && + tensors.front().scaling_mode() != NVTE_NVFP4_1D_SCALING) { return std::nullopt; } + const auto scaling_mode = tensors.front().scaling_mode(); + const auto nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING; + std::vector wrappers; std::vector input_tensors, output_tensors; @@ -127,39 +137,44 @@ std::optional multi_tensor_swizzle_scaling_factors( // Allocate full buffer auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + const auto input_dtype = + (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3; + const auto scale_inv_dtype = + (nvfp4) ? 
transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0; + for (size_t i = 0; i < tensors.size(); ++i) { auto& tensor = tensors[i]; void* scale_inv_dptr = scale_inv_dptrs[i]; void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); - auto input_shape = nvte_shape_to_vector(tensor.shape()); - + // auto input_shape = nvte_shape_to_vector(tensor.shape()); + NVTEShape nvte_input_shape; + if (rowwise) { + nvte_input_shape = tensor.shape(); + } else { + nvte_input_shape = tensor.get_columnwise_data().shape; + } + auto input_shape = nvte_shape_to_vector(nvte_input_shape); // Reconstruct input only to avoid swizzling both directions if not needed. // Use any 8 bit type, it's irrelevant. - transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); - transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper input_cu(scaling_mode); + transformer_engine::TensorWrapper output_cu(scaling_mode); if (rowwise) { - input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); - input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); - output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); // Set the swizzled scaling factor to the original tensor. 
- tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); } else { - input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, - input_shape); - input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, - scale_inv_shapes[i]); - output_cu.set_columnwise_data(tensor.columnwise_dptr(), - transformer_engine::DType::kFloat8E4M3, input_shape); - output_cu.set_columnwise_scale_inv( - swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape); + output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); // Set the swizzled scaling factor to the original tensor. 
- tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, - transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, + scale_inv_shapes[i]); } input_tensors.emplace_back(input_cu.data()); @@ -175,4 +190,72 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input, + bool rowwise) { + using namespace transformer_engine::pytorch; + using transformer_engine::DIVUP; + + // Check input tensor + const NVTEScalingMode scaling_mode = input.scaling_mode(); + NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D, + "Input tensor must be a block scaling tensor"); + + // Get tensor data + NVTEBasicTensor data; + size_t data_flat_first_dim = 1; + size_t data_flat_last_dim = 1; + if (rowwise) { + data = input.get_rowwise_data(); + for (int i = 0; i < data.shape.ndim - 1; ++i) { + data_flat_first_dim *= data.shape.data[i]; + } + data_flat_last_dim = data.shape.data[data.shape.ndim - 1]; + } else { + data = input.get_columnwise_data(); + data_flat_first_dim = data.shape.data[0]; + for (int i = 1; i < data.shape.ndim; ++i) { + data_flat_last_dim *= data.shape.data[i]; + } + } + NVTEShape data_shape{}; + data_shape.data[0] = data_flat_first_dim; + data_shape.data[1] = data_flat_last_dim; + data_shape.ndim = 2; + + // Recreate input tensor with rowwise usage + transformer_engine::TensorWrapper input_cu(scaling_mode); + input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + const NVTEBasicTensor scale_inv = + rowwise ? 
input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv(); + input_cu.set_rowwise_scale_inv( + scale_inv.data_ptr, static_cast(scale_inv.dtype), scale_inv.shape); + + // Create output tensor + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape); + // Output swizzled mxfp8 scaling factor dimensions + const size_t swizzled_scale_inv_first_dim = DIVUP(data_flat_first_dim, 128) * 128; + const size_t swizzled_scale_inv_last_dim = DIVUP(data_flat_last_dim, 128) * 4; + // Allocate memory for swizzled mxfp8 scaling factors + const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA); + at::Tensor swizzled_scale_inv = at::empty( + std::vector{swizzled_scale_inv_first_dim, swizzled_scale_inv_last_dim}, options); + // Set rowwise scaling factors on output + void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0); + NVTEShape swizzled_scale_inv_shape{}; + swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim; + swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim; + swizzled_scale_inv_shape.ndim = 2; + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + swizzled_scale_inv_shape); + + // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format + nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(), + at::cuda::getCurrentCUDAStream()); + + // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor + // for it to be kept alive during the GEMM + input = std::move(output_cu); + return swizzled_scale_inv; +} #endif //!USE_ROCM diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 621cc1db8..9a46ae86d 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -31,6 +31,17 @@ std::optional 
swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); +/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place. + * + * If rowwise==false, the columnwise data will be reinterpreted as rowwise data to avoid + * transposing it in memory. Due to differences in how block scaling and mxfp8 store data, + * this requires the calling code to treat the output tensor as having been transposed in this case. + * + * Returns the swizzled scaling factor of the converted mxfp8 tensor. + * The returned swizzled scaling factor tensor should be kept alive during the GEMM. + */ +at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper &input, + bool rowwise); #endif //!USE_ROCM #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ diff --git a/transformer_engine/pytorch/custom_recipes/__init__.py b/transformer_engine/pytorch/custom_recipes/__init__.py new file mode 100644 index 000000000..6e859ba5d --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Experimental features and APIs.""" diff --git a/transformer_engine/pytorch/custom_recipes/gemm.py b/transformer_engine/pytorch/custom_recipes/gemm.py new file mode 100644 index 000000000..cc98a8a57 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/gemm.py @@ -0,0 +1,137 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""GEMM API that enables custom GEMM logic for custom quantization recipes.""" + +from typing import Iterable, Optional + +import torch + +from transformer_engine.pytorch.custom_recipes.quantization import ( + MMParams, + GEMMType, +) +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer +from transformer_engine.pytorch.tensor.utils import is_custom + + +def custom_gemm( + A: QuantizedTensorStorage, + B: QuantizedTensorStorage, + workspace: torch.Tensor, # pylint: disable=unused-argument + out_dtype: Optional[torch.dtype] = None, + quantization_params: Optional[Quantizer] = None, # pylint: disable=unused-argument + gelu: bool = False, # pylint: disable=unused-argument + gelu_in: torch.Tensor = None, # pylint: disable=unused-argument + accumulate: bool = False, # pylint: disable=unused-argument + layout: str = "TN", + out: Optional[torch.Tensor] = None, # pylint: disable=unused-argument + bias: Optional[torch.Tensor] = None, + use_split_accumulator: bool = False, + grad: bool = False, +) -> Iterable[Optional[torch.Tensor]]: + """Dispatch GEMM to quantizer's qgemm method.""" + assert is_custom(A) and is_custom(B), "A and B must be custom tensors" + + A, B = B, A + + # Determine GEMM type based on grad flag and layout + if not grad: + gemm_type = GEMMType.FPROP + else: + if layout == "NN": + gemm_type = GEMMType.DGRAD + elif layout == "NT": + gemm_type = GEMMType.WGRAD + else: + # Default to FPROP for other layouts + gemm_type = GEMMType.FPROP + + # Extract quantizer from QuantizedTensor to get qgemm logic + # TODO(etsykunov): make it more flexible, what if we might want to use gemm logic from B._quantizer? 
+ quantizer = None + if hasattr(A, "_quantizer") and A._quantizer is not None: + quantizer = A._quantizer + elif hasattr(B, "_quantizer") and B._quantizer is not None: + quantizer = B._quantizer + else: + raise ValueError("No quantizer found in QuantizedTensor objects") + + # Create MMParams + m_params = MMParams( + out_dtype=out_dtype, + use_split_accumulator=use_split_accumulator, + ) + out_dtype = A.dtype if m_params.out_dtype is None else m_params.out_dtype + + if gemm_type == GEMMType.FPROP: + qx, sx = A.data, A.scale + qw, sw = B.data, B.scale + assert qx is not None + assert sx is not None + assert qw is not None + assert sw is not None + assert A.original_shape is not None + + # Call quantizer's qgemm method + result = quantizer.qgemm( + qx, + qw, + m_params, + out_dtype, + sx, + sw, + bias, + gemm_type=GEMMType.FPROP, + qresult_x=A, + qresult_w=B, + ) + if len(A.original_shape) > 2: + # Original input was 3D, so we need to reshape result back to 3D + batch_size = A.original_shape[0] + seq_len = A.original_shape[1] + result = result.view(batch_size, seq_len, result.shape[-1]) + elif gemm_type == GEMMType.DGRAD: + qdy, sdy = A.data, A.scale + qw_t, sw_t = B.data_t, B.scale_t + assert qdy is not None + assert sdy is not None + assert qw_t is not None + assert sw_t is not None + + result = quantizer.qgemm( + qdy, + qw_t, + m_params, + out_dtype, + sdy, + sw_t, + None, + gemm_type=GEMMType.DGRAD, + qresult_x=A, + qresult_w=B, + ) + elif gemm_type == GEMMType.WGRAD: + qdy_t, sdy_t = A.data_t, A.scale_t + qx_t, sx_t = B.data_t, B.scale_t + assert qdy_t is not None + assert sdy_t is not None + assert qx_t is not None + assert sx_t is not None + + result = quantizer.qgemm( + qdy_t, + qx_t, + m_params, + out_dtype, + sdy_t, + sx_t, + None, + gemm_type=GEMMType.WGRAD, + qresult_x=A, + qresult_w=B, + ) + + # Return in the same format as general_gemm + return result, None, None, None diff --git a/transformer_engine/pytorch/custom_recipes/quantization.py 
b/transformer_engine/pytorch/custom_recipes/quantization.py new file mode 100644 index 000000000..876ca7fcb --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization.py @@ -0,0 +1,29 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization API for experimental middleware between Transformer Engine and Kitchen.""" + +from __future__ import annotations +import dataclasses +import enum + +import torch + + +@enum.unique +class GEMMType(enum.Enum): + """Type of GEMM operation being performed.""" + + FPROP = "fprop" + DGRAD = "dgrad" + WGRAD = "wgrad" + + +@dataclasses.dataclass(frozen=True) +class MMParams: + """Matrix multiplication parameters.""" + + out_dtype: torch.dtype | None = None + # Use split accumulator for more accurate FP8 GEMM + use_split_accumulator: bool = True diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py new file mode 100644 index 000000000..1ce9079eb --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -0,0 +1,887 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""NVFP4 recipe reference implementation.""" + +import dataclasses +from typing import Optional, Tuple, Union + +import torch + +from transformer_engine.pytorch.custom_recipes import quantization +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage, Quantizer + + +def nvfp4_ref_rht_2d_quantizer_factory(role): + """ + Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). 
+ + Usage with CustomRecipe and fp8_autocast: + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) + with fp8_autocast(fp8_recipe=custom_recipe): + output = model(input) + """ + if role == "linear_input": + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) + if role == "linear_weight": + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(16, 16), + pow_2_scales=False, + with_rht=False, + ) + if role == "linear_grad_output": + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) + return None + + +def cast_to_fp4x2(x): + """Quantize a tensor to FP4 E2M1 and store in a byte tensor""" + + result = torch.zeros_like(x, dtype=torch.uint8) + result[(x >= 0.0) & (x <= 0.25)] = 0 + result[(x > 0.25) & (x < 0.75)] = 1 + result[(x >= 0.75) & (x <= 1.25)] = 2 + result[(x > 1.25) & (x < 1.75)] = 3 + result[(x >= 1.75) & (x <= 2.5)] = 4 + result[(x > 2.5) & (x < 3.5)] = 5 + result[(x >= 3.5) & (x <= 5.0)] = 6 + result[x > 5.0] = 7 + + result[(x >= -0.25) & (x < -0.0)] = 8 + result[(x < -0.25) & (x > -0.75)] = 9 + result[(x <= -0.75) & (x >= -1.25)] = 10 + result[(x < -1.25) & (x > -1.75)] = 11 + result[(x <= -1.75) & (x >= -2.5)] = 12 + result[(x < -2.5) & (x > -3.5)] = 13 + result[(x <= -3.5) & (x >= -5.0)] = 14 + result[x < -5.0] = 15 + + return result[:, ::2] + result[:, 1::2] * 16 + + +def cast_from_fp4x2(x, dq_dtype): + """Dequantize FP4 E2M1 tensor that has been represented in a byte tensor""" + fp4_values = torch.tensor( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + device=x.device, + dtype=dq_dtype, + ) + + # Convert to long integers for indexing + second_bit = torch.div(x, 16, rounding_mode="floor").to(torch.long) + first_bit = (x - second_bit * 16).to(torch.long) + + # Use the 
long integers to index fp4_values + first_bit_values = fp4_values[first_bit] + second_bit_values = fp4_values[second_bit] + + result = torch.zeros( + (first_bit_values.shape[0], first_bit_values.shape[1] * 2), + device=x.device, + dtype=dq_dtype, + ) + result[:, ::2] = first_bit_values + result[:, 1::2] = second_bit_values + + return result + + +def cast_to_e8(decode_scale): + """Cast to a value that is representable in FP8 E8M0. + + The result is in FP32, not FP8 E8M0. + """ + max_exponent = torch.tensor(127, device=decode_scale.device, dtype=torch.float32) + exponent = torch.ceil(torch.log2(decode_scale)) + exponent = torch.clamp(exponent, min=-max_exponent, max=max_exponent) + + return torch.tensor(2.0, device=decode_scale.device, dtype=torch.float32) ** exponent + + +def cast_to_e4m3(decode_scale, global_amax): + """Scale and cast to FP8 E4M3. + + decode_scale is actually the encoding scaling factor. global_amax + can be any data tensor and not just the amax. + + TODO(etsykunov): Make less unintuitive. 
+ """ + decode_scale = decode_scale * global_amax + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=decode_scale.device, dtype=torch.float32) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + return decode_scale.to(torch.float8_e4m3fn) + + +def high_precision_gemm_ref( + a: torch.Tensor, + b: torch.Tensor, + out_dtype: torch.dtype, + accumulate: bool = False, + is_a_transposed: bool = False, + is_b_transposed: bool = False, + out: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + scale_alpha: float = 1.0, +) -> torch.Tensor: + """GEMM implementation with unquantized data""" + # Handle transpositions + mat1, mat2 = a, b + if is_a_transposed: + mat1 = a.T + if is_b_transposed: + mat2 = b.T + + # Ensure dtype compatibility for torch.addmm + mat1 = mat1.to(out_dtype) + mat2 = mat2.to(out_dtype) + + # Determine output shape + y_shape = (mat1.size(0), mat2.size(1)) + + if bias is not None: + assert not accumulate, "Bias is not supported with accumulation" + bias = bias.to(out_dtype) + # With bias case + if out_dtype == torch.float32: + y_ref = torch.addmm(bias.repeat(mat1.size(0), 1), mat1, mat2, beta=1, alpha=1) + else: + y_ref = torch.addmm(bias, mat1, mat2, beta=1, alpha=scale_alpha) + else: + # Without bias case + if accumulate and out is not None: + y_ref = out.clone().to(out_dtype) + else: + y_ref = torch.zeros(y_shape, dtype=out_dtype, device=a.device) + torch.addmm(y_ref, mat1, mat2, beta=1, alpha=scale_alpha, out=y_ref) + + return y_ref + + +@dataclasses.dataclass +class NVFP4TensorRef(QuantizedTensorStorage): + """NVFP4 tensor for middleware between Transformer Engine and Kitchen. + + Custom container to hold quantization result, including quantized tensor, optional + transposed quantized tensor, and corresponding decoding scales. + + data: torch.Tensor + the quantized tensor. + scale: torch.Tensor + the decoding scale for the quantized tensor. Shape depends on the scaling granularity. 
+ - if scaling type is PER_TENSOR, it should be a 1D scalar tensor. + data_t: torch.Tensor + the transposed quantized tensor (computed lazily if needed). + scale_t: torch.Tensor + the decoding scale for the transposed quantized tensor. + dtype: torch.dtype + nominal tensor datatype. + device: torch.device + device of the tensor. + quant_dtype: Union[utils.Fp4Formats, torch.dtype] + low precision tensor datatype. + original_shape: Tuple[int, ...] + original shape of the tensor. + _quantizer: Quantizer + Builder class for quantized tensor. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + data_t: Optional[torch.Tensor] = None + scale_t: Optional[torch.Tensor] = None + global_amax_row: Optional[torch.Tensor] = None + global_amax_col: Optional[torch.Tensor] = None + + dtype: Optional[torch.dtype] = None + device: Optional[torch.device] = None + quant_dtype: Optional[Union[utils.Fp4Formats, torch.dtype]] = None + original_shape: Optional[Tuple[int, ...]] = None + _quantizer: Optional[Quantizer] = None + + @property + def custom(self) -> bool: + """Flag to indicate this quantized tensor is custom.""" + return True + + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: + """Prepare the quantization result for saving for backward""" + tensors = [self.data, self.data_t, self.scale, self.scale_t] + self.data = None + self.data_t = None + self.scale = None + self.scale_t = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the quantization result from the saved tensors""" + self.data = tensors[0] + self.data_t = tensors[1] + self.scale = tensors[2] + self.scale_t = tensors[3] + return tensors[4:] + + # Compatibility + @property + def _data(self): + return self.data + + @_data.setter + def _data(self, value): + self.data = value + + @property + def _scale_inv(self): + return self.scale + + 
@_scale_inv.setter + def _scale_inv(self, value): + self.scale = value + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"dtype={self.dtype}, " + f"device={self.device}, " + f"quant_dtype={self.quant_dtype}, " + f"original_shape={self.original_shape}" + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """Generate or remove quantized data based on provided usage.""" + has_data = self.data is not None + has_data_transpose = self.data_t is not None + needs_data = has_data + needs_data_transpose = has_data_transpose + + if rowwise_usage is not None: + needs_data = rowwise_usage + if columnwise_usage is not None: + needs_data_transpose = columnwise_usage + + # Generate data that is required + if needs_data and not has_data: + raise RuntimeError("Cannot generate FP4 data, even from FP4 data transpose") + if needs_data_transpose and not has_data_transpose: + if not has_data: + raise RuntimeError("FP4 data is required to generate FP4 data transpose") + self._create_transpose() + + # Delete data that is not required + if not needs_data: + self.data = None + if not needs_data_transpose: + self.data_t = None + + def _create_transpose(self): + """Create transposed quantized tensor""" + if not self.data.is_contiguous(): + self.data = self.data.contiguous() + self.data_t = self.data.t().contiguous() + self.scale_t = self.scale + + def size(self, *args, **kwargs): # pylint: disable=unused-argument + """Return the original tensor shape, not the internal packed data shape. + + FP4 quantization packs two 4-bit values into each 8-bit value, which reduces + the second dimension by half. This method returns the logical shape that + users expect, not the internal packed storage shape. 
+ """ + assert self.original_shape is not None + return torch.Size(self.original_shape) + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded signs for Hadamard transform""" + return torch.tensor( + [1.0, 1.0, 1.0, -1.0, 1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, -1.0, 1.0, -1.0, -1.0], + dtype=torch.float32, + ) + + +class NVFP4QuantizerRef(Quantizer): + """NVFP4 quantizer for middleware between Transformer Engine and Kitchen""" + + def __init__( + self, + dtype: utils.Fp4Formats, + rowwise: bool = True, + columnwise: bool = True, + pow_2_scales: bool = False, + eps: float = 0.0, + quant_tile_shape: Tuple[int, int] = (1, 16), + with_rht: bool = False, + with_random_sign_mask: bool = True, + ): + super().__init__(rowwise=rowwise, columnwise=columnwise) + self.internal = True + + self.dtype = dtype + self.pow_2_scales = pow_2_scales + self.eps = eps + self.quant_tile_shape = quant_tile_shape + self.with_rht = with_rht + self.with_random_sign_mask = with_random_sign_mask + + @property + def custom(self) -> bool: + """Flag to indicate this quantizer is custom.""" + return True + + @staticmethod + def _build_hadamard_matrix( + size: int, device: torch.device, dtype: torch.dtype, with_random_sign_mask: bool = True + ) -> torch.Tensor: + """Construct a Hadamard matrix of given power-of-two size with entries +-1. + + Uses Sylvester construction to avoid SciPy dependency. + """ + assert (size & (size - 1)) == 0, "Hadamard size must be a power of two" + h = torch.ones((1, 1), device=device, dtype=torch.float32) + while h.shape[0] < size: + h = torch.cat( + [ + torch.cat([h, h], dim=1), + torch.cat([h, -h], dim=1), + ], + dim=0, + ) + if with_random_sign_mask: + sign_mat = get_wgrad_sign_vector().to(device) * torch.eye( + size, device=device, dtype=torch.float32 + ) + h = sign_mat @ h + return h.to(dtype) + + def _apply_rht(self, x: torch.Tensor) -> torch.Tensor: + """Apply randomized Hadamard transform without random signs (reference path). 
+ + This matches the reference used in tests: x_reshaped @ (H * (1/sqrt(g))). + """ + # Only apply when enabled + if not self.with_rht: + return x + + # RHT dimension equals the quantization tile length (NVFP4 uses 16) + rht_dim = self.quant_tile_shape[1] + assert ( + x.shape[-1] % rht_dim == 0 + ), f"Inner dimension {x.shape[-1]} must be divisible by hadamard dimension {rht_dim}" + + # Build H and scale + H = self._build_hadamard_matrix(rht_dim, x.device, x.dtype, self.with_random_sign_mask) + scale = 1.0 / float(rht_dim) ** 0.5 + + # Perform blockwise transform along the last dimension + original_shape = x.shape + x_mat = x.contiguous().view(-1, rht_dim) + # Random sign matrix is identity in this reference (no sign flipping) + transform = H * scale + out = x_mat @ transform + return out.view(original_shape) + + @staticmethod + def _recover_swizzled_scales( + swizzled_scale: bool, scale: torch.Tensor, m: int, n: int, block_length: int + ) -> torch.Tensor: + if not swizzled_scale: + return scale + rounded_m = utils.roundup_div(m, 128) * 128 + scale_n = utils.roundup_div(n, block_length) + rounded_n = utils.roundup_div(scale_n, 4) * 4 + # Recover swizzled scaling factor layout -> linear layout + tmp = torch.reshape(scale, (rounded_m // 128, rounded_n // 4, 32, 4, 4)) + # after permutation, the layout is [rounded_m // 128, 4, 32, rounded_n // 4, 4] + tmp = torch.permute(tmp, (0, 3, 2, 1, 4)) + result = torch.reshape(tmp, (rounded_m, rounded_n)) + return result[:m, :scale_n] + + @classmethod + def _quantize_blockwise_reference( + cls, + x: torch.Tensor, + global_amax: torch.Tensor, + tile_len_x: int, + tile_len_y: int, + *, + pow_2_scales: bool, + eps: float, # pylint: disable=unused-argument + ) -> Tuple[torch.Tensor, torch.Tensor]: + + assert x.ndim == 2 + using_2d_quantization = tile_len_x == 16 and tile_len_y == 16 + m, n = x.shape + # Compute vec_max based on the original x (before reshape) + # For 1D quantization: amax over each row chunk of 16 + # For 2D 
quantization: amax over each 16x16 block, but output shape is still (128, 8, 1), filled with block amax + if using_2d_quantization: + # x shape: (128, 128) + x_blocks = ( + x.unfold(0, tile_len_y, tile_len_y) + .unfold(1, tile_len_x, tile_len_x) + .to(torch.float32) + ) # (8, 8, 16, 16) + block_amax = torch.amax(torch.abs(x_blocks), dim=(-1, -2)) # (8, 8) + # Now, expand to (128, 8, 1) by repeating each block_amax for 16 rows + vec_max = block_amax.repeat_interleave(tile_len_y, dim=0).unsqueeze(-1) # (128, 8, 1) + else: + # x shape: (128, 128) + x_reshaped = x.view(m, n // tile_len_x, tile_len_x) # (128, 8, 16) + vec_max = torch.amax(torch.abs(x_reshaped), dim=-1, keepdim=True).to( + torch.float32 + ) # (128, 8, 1) + x = x.view(m, n // tile_len_x, tile_len_x) + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) + + if pow_2_scales: + decode_scale = cast_to_e8(decode_scale) + encode_scale = torch.div( + torch.tensor(1.0, device=x.device, dtype=torch.float32), + decode_scale.to(torch.float32), + ) + else: + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.min( + global_encode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=global_encode_scale.device, + dtype=torch.float32, + ), + ) + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + global_decode_scale = torch.div(1.0, global_encode_scale) + + decode_scale = decode_scale * global_encode_scale + decode_scale = torch.min( + decode_scale, + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + decode_scale = torch.clamp(decode_scale, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX) + decode_scale = 
decode_scale.to(torch.float8_e4m3fn) + + encode_scale = torch.min( + torch.div(1.0, decode_scale.to(torch.float32) * global_decode_scale), + torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale.device, + dtype=torch.float32, + ), + ) + + scaled_x = x.to(torch.float32) * encode_scale + + clipped_x = torch.clamp(scaled_x, -FLOAT4_E2M1_MAX, FLOAT4_E2M1_MAX).reshape(m, n) + + return cast_to_fp4x2(clipped_x), decode_scale.squeeze(-1) + + @staticmethod + def _pad_tensor( + tensor: torch.Tensor, row_divisor: Optional[int], col_divisor: Optional[int] + ) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = tensor.shape + padding_needed_rows = 0 + padding_needed_cols = 0 + + if row_divisor is not None and M % row_divisor != 0: + padding_needed_rows = row_divisor - (M % row_divisor) + # Check and calculate column padding if col_divisor is provided + if col_divisor is not None and N % col_divisor != 0: + padding_needed_cols = col_divisor - (N % col_divisor) + + # Return original tensor if no padding is needed + if padding_needed_rows == 0 and padding_needed_cols == 0: + return tensor + + # pad the tensor + out = torch.nn.functional.pad( + tensor, + (0, padding_needed_cols, 0, padding_needed_rows), + mode="constant", + value=0.0, + ).contiguous() + + return out + + @staticmethod + def _rm_pad_tensor(tensor: torch.Tensor, original_size: tuple[int, ...]) -> torch.Tensor: + + assert tensor.dim() == 2, "only supports 2D tensors" + M, N = original_size + out = tensor[:M, :N].contiguous() + return out + + def _quantize(self, tensor: torch.Tensor) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + torch.Tensor, + torch.Tensor, + ]: + """ + Python implementation of microblock FP4 quantization. 
+ + Parameters + ---------- + tensor : torch.Tensor + Input tensor to quantize (should be 2D) + + Returns + ------- + Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor], torch.Tensor, torch.Tensor] + (qx, sx, qx_t, sx_t, global_amax_row, global_amax_col) where: + - qx: quantized data in row-major order (if rowwise_usage), None otherwise + - sx: scale tensor for qx (if rowwise_usage), None otherwise + - qx_t: quantized data in column-major order (if columnwise_usage), None otherwise + - sx_t: scale tensor for qx_t (if columnwise_usage), None otherwise + - global_amax_row, global_amax_col: global amax tensors + """ + if self.pow_2_scales: + assert self.quant_tile_shape == ( + 1, + 32, + ), "MXFP4 only supports 1x32 tile shape." + # TODO(etsykunov): Fix bug where global_amax_row and + # global_amax_col are not defined + # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) + else: + assert self.quant_tile_shape in ( + (1, 16), + (16, 16), + ), "NVFP4 only supports 1x16 or 16x16 tile shape." + # Prepare inputs once so we can reuse for both amax and quantization + # Row-input will always be the original input.
+ row_input = tensor + col_input = ( + self._apply_rht(tensor.t().contiguous()) + if self.with_rht + else tensor.t().contiguous() + ) + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) + + transpose_scales = False + + M, N = tensor.shape + if self.rowwise_usage: + x_input = row_input + x_padded = self._pad_tensor( + x_input, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx, sx = self._quantize_blockwise_reference( + x_padded, + global_amax_row, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + if transpose_scales: + sx = sx.T + + qx = self._rm_pad_tensor(qx, (M, N // 2)) + + else: + qx = None + sx = None + + if self.columnwise_usage: + x_t = col_input + x_t_padded = self._pad_tensor( + x_t, row_divisor=self.quant_tile_shape[0], col_divisor=self.quant_tile_shape[1] + ) + + qx_t, sx_t = self._quantize_blockwise_reference( + x_t_padded, + global_amax_col, + self.quant_tile_shape[1], + self.quant_tile_shape[0], + pow_2_scales=self.pow_2_scales, + eps=self.eps, + ) + + qx_t = self._rm_pad_tensor(qx_t, (N, M // 2)) + + if transpose_scales: + sx_t = sx_t.T + else: + qx_t = None + sx_t = None + + return qx, sx, qx_t, sx_t, global_amax_row, global_amax_col + + def quantize( + self, + tensor: torch.Tensor, + **kwargs, # pylint: disable=unused-argument + ) -> NVFP4TensorRef: + # sanity checks + assert tensor.dtype in utils.HIGH_PRECISION_FLOAT_DTYPES, "Unsupported input dtype." 
+ + # Make it work with 3D tensors + original_shape = tensor.shape + if tensor.ndim > 2: + tensor = tensor.view(-1, tensor.shape[-1]) + + qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(tensor) + + return NVFP4TensorRef( + data=qx, + scale=sx, + data_t=qx_t, + scale_t=sx_t, + global_amax_row=global_amax_row, + global_amax_col=global_amax_col, + dtype=tensor.dtype, + device=tensor.device, + quant_dtype=self.dtype, + _quantizer=self, + original_shape=original_shape, + ) + + def update_quantized( + self, + src: torch.Tensor, + dst: QuantizedTensorStorage, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> QuantizedTensorStorage: + """Update the quantized tensor with the given tensor in-place + + Parameters + ---------- + src: torch.Tensor + Source tensor to copy from + dst: QuantizedTensorStorage + Destination QuantizedTensorStorage to update + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + """ + # Handle noop flag + if noop_flag is not None and noop_flag.item() != 0: + return dst + + # Make sure input is in expected format + if not src.is_contiguous(): + src = src.contiguous() + + # Store the original shape and reshape for processing + original_shape = src.shape + if src.ndim > 2: + src = src.view(-1, src.shape[-1]) + + qx, sx, qx_t, sx_t, global_amax_row, global_amax_col = self._quantize(src) + + # Update the destination with new data + dst.data = qx + dst.scale = sx + dst.data_t = qx_t + dst.scale_t = sx_t + dst.global_amax_row = global_amax_row + dst.global_amax_col = global_amax_col + dst.dtype = src.dtype + dst.quant_dtype = self.dtype + dst.original_shape = original_shape + + return dst + + @property + def supports_allgather_fp8(self) -> bool: + """Whether the tensor data can be all-gathered with an FP8 all-gather. + + TODO(etsykunov): Confirm docstring is correct. Also, this API + seems too FP8-specific and should be reconsidered. 
""" + return False + + def transpose_qresult(self, qresult: QuantizedTensorStorage) -> QuantizedTensorStorage: + """Convert row-wise data to column-wise data (?) + + TODO(etsykunov): Confirm docstring is correct. + """ + raise NotImplementedError("Transpose qresult is not implemented for FP4.") + + @property + def supports_dequantize(self) -> bool: + """Whether quantized tensor can be converted to high-precision tensor""" + return False + + @property + def is_data_t_transposed_in_memory(self) -> bool: + """Whether column-wise data is stored in transposed layout. + + TODO(etsykunov): Confirm docstring is correct. + """ + raise NotImplementedError("Not implemented yet") + + def qgemm( + self, + qx: torch.Tensor, + qw: torch.Tensor, + m_params: quantization.MMParams, # pylint: disable=unused-argument + out_dtype: torch.dtype, + sx: torch.Tensor, + sw: torch.Tensor, + bias: torch.Tensor | None = None, + out: torch.Tensor | None = None, + accumulate: bool = False, + gemm_type: quantization.GEMMType = quantization.GEMMType.FPROP, + qresult_x: QuantizedTensorStorage | None = None, + qresult_w: QuantizedTensorStorage | None = None, + ) -> torch.Tensor: + """Python implementation of microblock FP4 GEMM.""" + assert bias is None, "Bias is not implemented for FP4 GEMM."
+ + high_precision_x = cast_from_fp4x2(qx, out_dtype) + high_precision_w = cast_from_fp4x2(qw, out_dtype) + + if self.pow_2_scales: + + if sx.dtype == torch.uint8: + # if scaling factor is stored in uint8 container + sx = torch.tensor(2.0, device=sx.device, dtype=torch.float32) ** ( + ( + sx.to(torch.float32) + - torch.tensor(127, device=sx.device, dtype=torch.float32) + ) + ) + sw = torch.tensor(2.0, device=sw.device, dtype=torch.float32) ** ( + ( + sw.to(torch.float32) + - torch.tensor(127, device=sw.device, dtype=torch.float32) + ) + ) + else: + # if scaling factor is torch.float8_e8m0fnu + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + alpha = torch.tensor(1.0, device=high_precision_x.device, dtype=torch.float32) + + else: + + assert qresult_x is not None + assert qresult_w is not None + + assert qresult_x.global_amax_row is not None + assert qresult_w.global_amax_col is not None + + sx = sx.to(torch.float32) + sw = sw.to(torch.float32) + + factor = 6.0 * 6.0 * 448.0 * 448.0 + + if gemm_type == quantization.GEMMType.WGRAD: + partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col + else: + partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row + alpha = torch.div(partial_alpha, factor).squeeze(-1) + + M, K = high_precision_x.shape + N, K_w = high_precision_w.shape + assert K == K_w, "K dimension mismatch between qx and qw" + + assert K % 32 == 0, "K dimension must be divisible by 32" + assert N % 8 == 0, "N dimension must be divisible by 8" + + block_length = 32 if self.pow_2_scales else 16 + + grid_k = K // block_length + + assert sx.shape == ( + M, + K // block_length, + ), f"sx shape mismatch: expected ({M}, {K//block_length}), got {sx.shape}" + assert sw.shape == ( + N, + K // block_length, + ), f"sw shape mismatch: expected ({N}, {K//block_length}), got {sw.shape}" + + y = torch.zeros(M, N, dtype=torch.float32, device=qx.device) + + # below implementation is to match the FP4 tensor core implementation + # Each output 
element (i, j) is fp32 accumulation of (K // block_length) inner products + # Each inner product is sx * sw * (1, block_length) x (block_length, 1) with precision in fp32 + # Then batch the computation in M, N dimension + for k in range(grid_k): + k_start = k * block_length + k_end = k_start + block_length + + qx_block = high_precision_x[:, k_start:k_end].clone().contiguous() + qw_block = high_precision_w[:, k_start:k_end].clone().contiguous() + + # Extract scaling factors for the current blocks + sx_block = sx[:, k] + sw_block = sw[:, k] + + y += torch.outer(sx_block, sw_block) * high_precision_gemm_ref( + qx_block, qw_block, torch.float32, is_b_transposed=True + ) + + if not self.pow_2_scales and K > 0: + # only apply global scale for NVFP4 and non-empty cases + y = alpha * y + + # accumulation happens at epilogue in float32 + if accumulate: + assert out is not None, "Output tensor must be provided for accumulation." + y += out.to(torch.float32) + else: + assert out is None, "Output tensor should be None when accumulate is False." + + y = y.to(out_dtype) + return y diff --git a/transformer_engine/pytorch/custom_recipes/utils.py b/transformer_engine/pytorch/custom_recipes/utils.py new file mode 100644 index 000000000..20dc6f11b --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/utils.py @@ -0,0 +1,30 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Utility functions for experimental middleware between Transformer Engine and Kitchen.""" + +import enum + +import torch + + +HIGH_PRECISION_FLOAT_DTYPES = ( + torch.float, + torch.float16, + torch.bfloat16, + torch.float32, +) + + +class Fp4Formats(enum.Enum): + """FP4 data format""" + + E2M1 = "e2m1" + + +def roundup_div(x: int, y: int) -> int: + """Round up division""" + assert x >= 0 + assert y > 0 + return (x + y - 1) // y diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index e809528da..04ffa324d 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -31,21 +31,25 @@ import transformer_engine_torch as tex +from transformer_engine.pytorch.triton.pad import pad_columnwise_scale_inv from . import torch_version from .utils import ( is_non_tn_fp8_gemm_supported, safely_set_viewless_tensor_data, needs_quantized_gemm, ) + from .constants import dist_group_type -from .fp8 import FP8GlobalStateManager, fp8_autocast +from .quantization import FP8GlobalStateManager, autocast from .tensor.float8_tensor import Float8Quantizer, Float8Tensor, Float8CurrentScalingQuantizer from .tensor.mxfp8_tensor import MXFP8Quantizer +from .tensor.nvfp4_tensor import NVFP4Quantizer from .tensor.float8_blockwise_tensor import Float8BlockQuantizer -from .tensor.quantized_tensor import QuantizedTensor, Quantizer -from .tensor._internal.float8_tensor_base import Float8TensorBase -from .tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from .tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from .quantized_tensor import QuantizedTensorStorage, QuantizedTensor, Quantizer +from .tensor.storage.float8_tensor_storage import Float8TensorStorage +from .tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage +from .tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .tensor.storage.float8_blockwise_tensor_storage import 
Float8BlockwiseQTensorStorage from ..debug.pytorch.debug_quantization import DebugQuantizedTensor, DebugQuantizer @@ -418,8 +422,8 @@ def backward( detached_inputs = detach_variable(inputs) with torch.enable_grad(), ctx.recompute_ctx, ctx.torch_gpu_amp_ctx, ctx.torch_cpu_amp_ctx, activation_recompute_forward( activation_recompute=True, recompute_phase=True - ), fp8_autocast( - enabled=ctx.fp8, fp8_recipe=ctx.fp8_recipe + ), autocast( + enabled=ctx.fp8, recipe=ctx.fp8_recipe ): outputs = ctx.run_function(*detached_inputs, **ctx.kwargs) @@ -753,8 +757,8 @@ def checkpoint( def recompute_fn(*args, **kwargs): with torch.autograd.enable_grad(), ( te_recompute_ctx - ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, fp8_autocast( - enabled=fp8, fp8_recipe=fp8_recipe + ), user_recompute_ctx, torch_gpu_amp_forward_ctx, torch_cpu_amp_forward_ctx, autocast( + enabled=fp8, recipe=fp8_recipe ): function(*args, **kwargs) @@ -908,7 +912,7 @@ def _all_gather_fp8( async_op: bool = False, quantizer: Optional[Quantizer] = None, out_shape: Optional[list[int]] = None, -) -> tuple[Float8TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[Float8TensorStorage, Optional[torch.distributed.Work]]: """All-gather FP8 tensor along first dimension.""" world_size = get_distributed_world_size(process_group) @@ -926,7 +930,7 @@ def _all_gather_fp8( # Cast input tensor to FP8 if needed # Note: We cannot directly all-gather the transposed FP8 tensor, # so temporarily modify quantizer to avoid creating FP8 transpose. 
- if not isinstance(inp, Float8TensorBase): + if not isinstance(inp, Float8TensorStorage): assert isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)) # we cannot directly gather the transposed fp8 tensor # so we need to disable columnwise usage for the quantizer @@ -941,14 +945,20 @@ def _all_gather_fp8( ) # Construct output tensor - out: Float8TensorBase + out: Float8TensorStorage if quantizer is not None: dtype = torch.float32 device = "cuda" if isinstance(inp, Float8Tensor): dtype = inp.dtype device = inp.device + # Temporarily ensure rowwise usage for output tensor creation + # since we're gathering rowwise data, not the transpose + init_rowwise_usage = quantizer.rowwise_usage + init_columnwise_usage = quantizer.columnwise_usage + quantizer.set_usage(rowwise=True, columnwise=init_columnwise_usage) out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + quantizer.set_usage(rowwise=init_rowwise_usage, columnwise=init_columnwise_usage) elif isinstance(inp, Float8Tensor): out = inp.make_like(inp, shape=out_shape) out._data = torch.empty( @@ -959,7 +969,7 @@ def _all_gather_fp8( out._transpose = None out._transpose_invalid = True else: - raise RuntimeError("FP8TensorBase is not supported yet without Quantizer") + raise RuntimeError("Float8TensorStorage is not supported yet without Quantizer") # Assume scaling factors are identical across ranks out._scale_inv = inp._scale_inv @@ -1004,10 +1014,10 @@ def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: def _post_process_fp8_blockwise_gather( - out: Float8BlockwiseQTensorBase, + out: Float8BlockwiseQTensorStorage, quantizer: Float8BlockQuantizer, handle: Optional[torch.distributed.Work] = None, -) -> Float8BlockwiseQTensorBase: +) -> Float8BlockwiseQTensorStorage: """Post-process FP8 blockwise gather.""" if handle is not None: handle.wait() @@ -1016,12 +1026,8 @@ def _post_process_fp8_blockwise_gather( if out._is_gemm_ready_format(): return out - 
needs_columnwise_data_transpose = ( - quantizer is not None and quantizer.columnwise_usage and not is_non_tn_fp8_gemm_supported() - ) - need_rowwise_scale_transpose = ( - quantizer is not None and quantizer.rowwise_usage and not is_non_tn_fp8_gemm_supported() - ) + needs_columnwise_data_transpose = quantizer is not None and quantizer.columnwise_usage + need_rowwise_scale_transpose = quantizer is not None and quantizer.rowwise_usage # CuBLAS requires transpose of the scale inv tensor, suppose orig input is 256x1024 # columnwise compact format means doing 128x1 quantization of it @@ -1041,7 +1047,7 @@ def _post_process_fp8_blockwise_gather( class _FP8BlockwiseAllGatherAsyncHandle: """Handle for asynchronous FP8 blockwise all-gather.""" - tensor: Float8BlockwiseQTensorBase + tensor: Float8BlockwiseQTensorStorage quantizer: Float8BlockQuantizer async_handle: torch.distributed.Work _synchronized: bool = False @@ -1079,18 +1085,18 @@ def _all_gather_fp8_blockwise( if isinstance(inp, torch.Tensor): device = inp.device dtype = inp.dtype - elif isinstance(inp, Float8BlockwiseQTensorBase): + elif isinstance(inp, Float8BlockwiseQTensorStorage): if inp._rowwise_data is not None: device = inp._rowwise_data.device elif inp._columnwise_data is not None: device = inp._columnwise_data.device else: - raise ValueError("Got Float8BlockwiseQTensorBase input tensor without any data") + raise ValueError("Got Float8BlockwiseQTensorStorage input tensor without any data") dtype = torch.bfloat16 # Only has fp8 dtype. Guess BF16 for dequant. 
else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or Float8BlockwiseQTensorBase, " - f"found {inp.__class__.__name__})" + "Invalid type for input tensor (expected torch.Tensor or" + f" Float8BlockwiseQTensorStorage, found {inp.__class__.__name__})" ) world_size = get_distributed_world_size(process_group) @@ -1107,7 +1113,7 @@ def _all_gather_fp8_blockwise( # Doing BF16 gather for now as baseline because it's simpler if ( - not isinstance(inp, Float8BlockwiseQTensorBase) + not isinstance(inp, Float8BlockwiseQTensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1132,7 +1138,7 @@ def _all_gather_fp8_blockwise( # Set to compact usage in case the quantizer is not correctly configured orig_all_gather_usage = quantizer.all_gather_usage quantizer.all_gather_usage = True - if not isinstance(inp, Float8BlockwiseQTensorBase): + if not isinstance(inp, Float8BlockwiseQTensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1208,6 +1214,245 @@ def _all_gather_fp8_blockwise( return out, handle +def _swap_first_dims(tensor: torch.Tensor, world_size: int): + """ + Swap first 2 dimensions of a tensor to fix interleaved + data format after gathering transposed data. + + For more than 2 dimensions, we squash the trailing dimensions, + instead of the first few dimensions, that's because the shape + passed in this function is already transposed. + """ + + shape = tensor.shape + assert tensor.ndim >= 2, "Wrong number of dimensions for fixing interleave." + first_dim = shape[0] + flattened_trailing = math.prod(shape[1:]) + assert first_dim % world_size == 0, "Wrong dimensions for fixing interleave." 
+ tensor = tensor.reshape(world_size, first_dim // world_size, flattened_trailing)
+ tensor = tex.swap_first_dims(tensor, out=None)
+ return tensor.reshape(first_dim // world_size, flattened_trailing * world_size)
+
+
+def _post_process_nvfp4_gather(
+ out: NVFP4TensorStorage,
+ columnwise_data_interleaved: torch.Tensor,
+ columnwise_scale_inv_interleaved: torch.Tensor,
+ world_size: int,
+ handle: Optional[torch.distributed.Work] = None,
+) -> NVFP4TensorStorage:
+ """Post-process NVFP4 gather."""
+ if handle is not None:
+ handle.wait()
+ handle = None
+
+ # Fix the interleaved transposed data from gathering along first dim.
+ out._columnwise_scale_inv = _swap_first_dims(columnwise_scale_inv_interleaved, world_size)
+ out._columnwise_data = _swap_first_dims(columnwise_data_interleaved, world_size)
+
+ # Optionally pad the scaling inverse if needed.
+ out._columnwise_scale_inv = pad_columnwise_scale_inv(out._columnwise_scale_inv)
+
+
+@dataclass
+class _NVFP4AllGatherAsyncHandle:
+ """Handle for asynchronous NVFP4 all-gather."""
+
+ output: NVFP4TensorStorage
+ columnwise_data_interleaved: torch.Tensor
+ columnwise_scale_inv_interleaved: torch.Tensor
+ world_size: int
+ async_handle: torch.distributed.Work
+ _synchronized: bool = False
+
+ def wait(self) -> None:
+ """Wait for the async operation to complete and post-process the tensor."""
+ if self._synchronized:
+ return
+ self.async_handle.wait()
+ _post_process_nvfp4_gather(
+ self.output,
+ self.columnwise_data_interleaved,
+ self.columnwise_scale_inv_interleaved,
+ self.world_size,
+ )
+ self._synchronized = True
+
+
+def _all_gather_nvfp4(
+ inp: torch.Tensor,
+ process_group: dist_group_type,
+ *,
+ async_op: bool = False,
+ quantizer: NVFP4Quantizer,
+ out_shape: Optional[list[int]] = None,
+) -> tuple[NVFP4TensorStorage, Optional[torch.distributed.Work]]:
+ """All-gather NVFP4 tensor along first dimension."""
+
+ # Input tensor attributes
+ in_shape: Iterable[int] = None
+ in_shape_t: 
Iterable[int] = None + device: torch.device + dtype: torch.dtype + + # Construct packed shapes for input and input_t. + if isinstance(inp, torch.Tensor) and not isinstance(inp, NVFP4TensorStorage): + # High-precision tensor. + in_shape = NVFP4Quantizer.convert_shape_for_fp4(inp.size()) + in_shape_t = NVFP4Quantizer.convert_shape_for_fp4( + NVFP4Quantizer.get_columnwise_shape(inp.size()) + ) + device = inp.device + dtype = inp.dtype + elif isinstance(inp, NVFP4TensorStorage): + if inp._rowwise_data is not None: + in_shape = inp._rowwise_data.size() + device = inp._rowwise_data.device + if inp._columnwise_data is not None: + in_shape_t = inp._columnwise_data.size() + device = inp._columnwise_data.device + dtype = torch.bfloat16 + else: + raise ValueError( + "Invalid type for input tensor (expected torch.Tensor or NVFP4TensorStorage, " + f"found {inp.__class__.__name__})" + ) + + assert in_shape is not None or in_shape_t is not None, "No data found." + + world_size = get_distributed_world_size(process_group) + + if out_shape is None: + out_shape = [in_shape[0] * world_size] + in_shape[1:] + + # For cases where inp has dimensions that cannot be quantized, + # we gather in high precision followed by a cast to NVFP4. + if ( + not isinstance(inp, NVFP4TensorStorage) + and quantizer is not None + and not quantizer.is_quantizable(inp) + ): + out = torch.empty( + out_shape, + dtype=dtype, + device=device, + memory_format=torch.contiguous_format, + ) + torch.distributed.all_gather_into_tensor(out, inp, group=process_group) + out = quantizer(out) + return out, None + + # Cast input tensor to NVFP4 with required data + if not isinstance(inp, NVFP4TensorStorage): + inp = quantizer(inp) + elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( + quantizer.columnwise_usage and inp._columnwise_data is None + ): + warnings.warn( + "Input and quantizer do not have matching usages. " + "Dequantizing and requantizing to NVFP4." 
+ ) + inp = quantizer(inp.dequantize()) + + # Construct NVFP4 output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + + # Coalesce NCCL collectives for gathering data and scale inverses. + with torch.distributed._coalescing_manager( + group=process_group, + device=device, + async_ops=async_op, + ) as gather_coalescing_manager: + + # Gather NVFP4 data for row-wise usage + if quantizer.rowwise_usage: + + # Remove padding from NVFP4 scale-inverses + assert in_shape is not None, "Shape not found." + in_scale_inv = inp._rowwise_scale_inv + out_scale_inv = out._rowwise_scale_inv + flattened_in_shape0 = math.prod(in_shape[:-1]) + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] + + # Launch all-gathers + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + torch.distributed.all_gather_into_tensor( + out._rowwise_data, + inp._rowwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_rowwise = inp._amax_rowwise + + # Gather the transposed NVFP4 data along first dimension. Fix format later. + if quantizer.columnwise_usage: + + # Remove padding from NVFP4 scale-inverses + # For doing an all-gather on transposed scale inverses, + # we need to remove padding from both dimension. + in_scale_inv = inp._columnwise_scale_inv + # take caution that for in_shape_t, flatten in the trailing dimensions! + flattened_in_shape0 = in_shape_t[0] + flattened_in_shape1 = math.prod(in_shape_t[1:]) + + # Remove dim0 padding + if in_scale_inv.size(0) != flattened_in_shape0: + in_scale_inv = in_scale_inv[:flattened_in_shape0] + + # Remove dim1 padding (pack first). 
+ unpadded_dim1 = flattened_in_shape1 * 2 // 16 + if in_scale_inv.size(1) != unpadded_dim1: + in_scale_inv = in_scale_inv[:, :unpadded_dim1].contiguous() + + # Construct tensor to gather transposed scale_inv (interleaved) and launch AG. + out_scale_inv = torch.empty( + [flattened_in_shape0 * world_size] + [in_scale_inv.shape[1]], + dtype=in_scale_inv.dtype, + layout=in_scale_inv.layout, + device=in_scale_inv.device, + ) + torch.distributed.all_gather_into_tensor( + out_scale_inv, + in_scale_inv, + group=process_group, + ) + + # Construct tensor to gather transposed data (interleaved) and launch AG. + out_columnwise_data = torch.empty( + [inp._columnwise_data.shape[0] * world_size] + list(inp._columnwise_data.shape[1:]), + dtype=inp._columnwise_data.dtype, + layout=inp._columnwise_data.layout, + device=inp._columnwise_data.device, + ) + torch.distributed.all_gather_into_tensor( + out_columnwise_data, + inp._columnwise_data, + group=process_group, + ) + + # Transfer amax to output. + out._amax_columnwise = inp._amax_columnwise + + handle = gather_coalescing_manager if async_op else None + + # Fixes interleaved data for transposed tensor/scale inv and pads scale inv if needed. 
+ if async_op and quantizer.columnwise_usage: + handle = _NVFP4AllGatherAsyncHandle( + out, out_columnwise_data, out_scale_inv, world_size, handle + ) + elif quantizer.columnwise_usage: + _post_process_nvfp4_gather(out, out_columnwise_data, out_scale_inv, world_size, handle) + + return out, handle + + def _all_gather_mxfp8( inp: torch.Tensor, process_group: dist_group_type, @@ -1215,7 +1460,7 @@ def _all_gather_mxfp8( async_op: bool = False, quantizer: MXFP8Quantizer, out_shape: Optional[list[int]] = None, -) -> tuple[MXFP8TensorBase, Optional[torch.distributed.Work]]: +) -> tuple[MXFP8TensorStorage, Optional[torch.distributed.Work]]: """All-gather MXFP8 tensor along first dimension.""" # Input tensor attributes @@ -1226,7 +1471,7 @@ def _all_gather_mxfp8( in_shape = inp.size() device = inp.device dtype = inp.dtype - elif isinstance(inp, MXFP8TensorBase): + elif isinstance(inp, MXFP8TensorStorage): if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device @@ -1238,7 +1483,7 @@ def _all_gather_mxfp8( dtype = torch.bfloat16 # Guess high-precision dtype. else: raise ValueError( - "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " + "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorStorage, " f"found {inp.__class__.__name__})" ) @@ -1250,7 +1495,7 @@ def _all_gather_mxfp8( # For cases where inp has dimensions that cannot be quantized, # we gather in high precision followed by a cast to FP8. 
if ( - not isinstance(inp, MXFP8TensorBase) + not isinstance(inp, MXFP8TensorStorage) and quantizer is not None and not quantizer.is_quantizable(inp) ): @@ -1265,7 +1510,7 @@ def _all_gather_mxfp8( return out, None # Cast input tensor to MXFP8 with required data - if not isinstance(inp, MXFP8TensorBase): + if not isinstance(inp, MXFP8TensorStorage): inp = quantizer(inp) elif (quantizer.rowwise_usage and inp._rowwise_data is None) or ( quantizer.columnwise_usage and inp._columnwise_data is None @@ -1295,7 +1540,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1319,7 +1563,6 @@ def _all_gather_mxfp8( flattened_in_shape0 = math.prod(in_shape[:-1]) // 32 if in_scale_inv.size(0) != flattened_in_shape0: in_scale_inv = in_scale_inv[:flattened_in_shape0] - out_scale_inv[flattened_in_shape0 * world_size :].zero_() out_scale_inv = out_scale_inv[: flattened_in_shape0 * world_size] # Launch all-gathers @@ -1351,7 +1594,7 @@ def gather_along_first_dim( # Return immediately if no communication is required world_size = get_distributed_world_size(process_group) if world_size == 1: - if quantizer is not None and not isinstance(inp, QuantizedTensor): + if quantizer is not None and not isinstance(inp, QuantizedTensorStorage): inp = quantizer(inp) return inp, None @@ -1398,7 +1641,7 @@ def gather_along_first_dim( out_shape[0] *= world_size # FP8 case: delayed scaling or current scaling - if isinstance(inp, Float8TensorBase) or isinstance( + if isinstance(inp, Float8TensorStorage) or isinstance( quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): return _all_gather_fp8( @@ -1410,7 +1653,9 @@ def gather_along_first_dim( ) # FP8 block scaling case, block length = 128 - if isinstance(inp, 
Float8BlockwiseQTensorBase) or isinstance(quantizer, Float8BlockQuantizer): + if isinstance(inp, Float8BlockwiseQTensorStorage) or isinstance( + quantizer, Float8BlockQuantizer + ): return _all_gather_fp8_blockwise( inp, process_group, @@ -1420,7 +1665,7 @@ def gather_along_first_dim( ) # MXFP8 case - if isinstance(inp, MXFP8TensorBase) or isinstance(quantizer, MXFP8Quantizer): + if isinstance(inp, MXFP8TensorStorage) or isinstance(quantizer, MXFP8Quantizer): assert isinstance(quantizer, MXFP8Quantizer) return _all_gather_mxfp8( inp, @@ -1430,13 +1675,24 @@ def gather_along_first_dim( out_shape=out_shape, ) + # NVFP4 case + if isinstance(inp, NVFP4TensorStorage) or isinstance(quantizer, NVFP4Quantizer): + assert isinstance(quantizer, NVFP4Quantizer) + return _all_gather_nvfp4( + inp, + process_group, + async_op=async_op, + quantizer=quantizer, + out_shape=out_shape, + ) + # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." ) - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorStorage): inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format @@ -1454,7 +1710,7 @@ def gather_along_first_dim( return out, None # Dequantize quantized tensor if not supported - if isinstance(inp, QuantizedTensor): + if isinstance(inp, QuantizedTensorStorage): warnings.warn( "Attempting to all-gather an unsupported quantized tensor. " "Falling back to high-precision all-gather." @@ -1640,6 +1896,43 @@ def allreduce( return inp, handle +def _get_module_fsdp_state(module): + """ + If module is an FSDP module, return its _FSDPState. + Otherwise, return the _FSDPState of the closest parent FSDP module + in the module hierarchy the module belongs to. 
+ """ + + if hasattr(module, "_get_fsdp_state"): + # this will return correct fsdp state if module itself is an fsdp module + fsdp_state = module._get_fsdp_state() + elif getattr(module, "_te_cached_parent_fsdp_state", None) is not None: + # See if we have cached the parent fsdp state of the module + fsdp_state = module._te_cached_parent_fsdp_state + else: + from torch.distributed._composable_state import _module_state_mapping + + # Otherwise get the fsdp state of lca of module in the module hierarchy + min_nodes_in_parent = float("inf") + closest_parent_fsdp_mod = None + for fsdp_mod in _module_state_mapping.keys(): + all_submodules = list(fsdp_mod.modules()) + for submodule in all_submodules: + if submodule is module: + if min_nodes_in_parent > len(all_submodules): + closest_parent_fsdp_mod = fsdp_mod + min_nodes_in_parent = len(all_submodules) + if closest_parent_fsdp_mod is None: + raise RuntimeError( + "Module is not FSDP-wrapped and does not have any FSDP-wrapped parent modules." + ) + fsdp_state = closest_parent_fsdp_mod._get_fsdp_state() + # Cache the parent fsdp state of the module to avoid recomputing + # the closest parent fsdp module. + module._te_cached_parent_fsdp_state = fsdp_state + return fsdp_state + + def _fsdp_scatter_tensors( fsdp_group: dist_group_type, *tensors: torch.Tensor, @@ -1724,7 +2017,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: if hasattr(fsdp_root, "primary_weights_in_fp8"): assert not fsdp_root.primary_weights_in_fp8, ( "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.fp8_model_init(...) context." + "Please initialize your model without the te.quantized_model_init(...) context." ) root_state = _get_module_fsdp_state(fsdp_root) assert root_state is not None, "Root module does not have a valid _FSDPState." 
@@ -1737,7 +2030,7 @@ def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None: if hasattr(fsdp_module.module, "primary_weights_in_fp8"): assert not fsdp_module.module.primary_weights_in_fp8, ( "TE modules with primary weights in FP8 cannot be FSDP-wrapped. " - "Please initialize your model without the te.fp8_model_init(...) context." + "Please initialize your model without the te.quantized_model_init(...) context." ) setattr(fsdp_module.module, "fsdp_group", state.process_group) diff --git a/transformer_engine/pytorch/float8_tensor.py b/transformer_engine/pytorch/float8_tensor.py index a771e3bb7..eeafc23c7 100644 --- a/transformer_engine/pytorch/float8_tensor.py +++ b/transformer_engine/pytorch/float8_tensor.py @@ -4,6 +4,16 @@ """Tensor class with FP8 data""" +import warnings + from .tensor.float8_tensor import Float8Tensor +warnings.warn( + "transformer_engine.pytorch.float8_tensor is deprecated and will be removed" + " in a future release. Float8Tensor should be imported directly through " + "`from transformer_engine.pytorch import Float8Tensor`", + DeprecationWarning, + stacklevel=2, +) + __all__ = ["Float8Tensor"] diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index 15cb88b00..f937b3de9 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -1,22 +1,27 @@ -# This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -"""FP8 utilities for TransformerEngine""" -from __future__ import annotations +""" +DEPRECATED in favor of `transformer_engine.pytorch.quantization.py`. 
+""" -import abc -import itertools -import os -from contextlib import contextmanager -from collections import deque -from typing import Callable, List, Optional, Dict, Any, Tuple, Union +# pylint: disable=wrong-import-position,unused-import -import torch -from torch.utils.cpp_extension import IS_HIP_EXTENSION -import transformer_engine_torch as tex +import warnings + +warnings.warn( + "Using deprecated internal API from Transformer Engine. " + "transformer_engine.pytorch.fp8 will be removed in a " + "future release.", + DeprecationWarning, + stacklevel=2, +) + + +# There are some users indirectly importing these classes +# from fp8.py. This ensure backwards compatibility. +# https://github.com/Lightning-AI/lightning-thunder/pull/2635. from transformer_engine.common.recipe import ( Recipe, DelayedScaling, @@ -24,1087 +29,40 @@ MXFP8BlockScaling, Float8CurrentScaling, Float8BlockScaling, + NVFP4BlockScaling, + CustomRecipe, ) -from .constants import dist_group_type -from .utils import get_device_compute_capability, get_torch_float8_e4m3_type, get_torch_float8_e5m2_type -from .jit import jit_fuser - -__all__ = ["fp8_autocast", "fp8_model_init"] - -def check_fp8_support() -> Tuple[bool, str]: - if IS_HIP_EXTENSION: - gpu_arch = get_device_compute_capability() - if gpu_arch in ((9, 4), (9, 5)): - return True, "" - else: - return False, "Device arch gfx94x or gfx95x required for FP8 execution." - else: - """Return if fp8 support is available""" - if get_device_compute_capability() >= (9, 0): # hopper and above - return True, "" - if get_device_compute_capability() < (8, 9): # pre-ada - return False, "Device compute capability 8.9 or higher required for FP8 execution." - if tex.get_cublasLt_version() < 120103: - return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." - if float(torch.version.cuda) < 12.1: - return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." 
- return True, "" - -def check_mxfp8_support() -> Tuple[bool, str]: - """Return if fp8 support is available""" - if IS_HIP_EXTENSION: - if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") == "0": - return False, "MXFP8 support is not enabled." - gpu_arch = get_device_compute_capability() - if gpu_arch == (9, 5): - return True, "" - return False, "Gfx95x is required for MXFP8 execution." - if get_device_compute_capability() >= (12, 0): - return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." - if get_device_compute_capability() >= (10, 0): # blackwell and above - return True, "" - return False, "Device compute capability 10.0 or higher required for MXFP8 execution." - - -def check_fp8_block_scaling_support() -> Tuple[bool, str]: - """Return if fp8 block scaling support is available""" - if IS_HIP_EXTENSION: - return False, "FP8 block scaled gemm not yet supported for ROCm" - if ( - get_device_compute_capability() >= (9, 0) - and get_device_compute_capability() < (10, 0) - and float(torch.version.cuda) >= 12.9 - ): - return True, "" - return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." 
- - -def check_recipe_support(recipe: Recipe) -> None: - """Check if the given recipe is supported.""" - recipe_supported = True - unsupported_reason = "" - if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): - recipe_supported, unsupported_reason = check_fp8_support() - elif isinstance(recipe, Float8BlockScaling): - recipe_supported, unsupported_reason = check_fp8_block_scaling_support() - elif isinstance(recipe, MXFP8BlockScaling): - recipe_supported, unsupported_reason = check_mxfp8_support() - assert recipe_supported, unsupported_reason - - -def get_default_fp8_recipe() -> Recipe: - """FP8 recipe with default args.""" - if IS_HIP_EXTENSION: - if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") != "2": - return DelayedScaling() - gpu_arch = get_device_compute_capability() - if gpu_arch == (9, 5): - return MXFP8BlockScaling() - return DelayedScaling() - if check_mxfp8_support()[0]: - return MXFP8BlockScaling() - if get_device_compute_capability() >= (12, 0): - # This is a temporary restriction until MXFP8 is supported for all gemm layouts. 
- return Float8CurrentScaling() - return DelayedScaling() - - -def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return get_torch_float8_e4m3_type() - return get_torch_float8_e5m2_type() - - -def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get fp8 data type according to recipe and tensor""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return tex.DType.kFloat8E4M3 - return tex.DType.kFloat8E5M2 - - -def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: - """Get max representible FP8 value.""" - if fp8_recipe.fp8_format == Format.E4M3 or ( - fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor - ): - return Format.E4M3.value.max_fwd - return Format.E5M2.value.max_fwd - - -class FP8GlobalStateManager: - """Class to keep track of and manipulate the global - FP8 state at different stages of execution. 
- """ - - FP8_ENABLED = False - FP8_CALIBRATION = False - FP8_RECIPE = None - FP8_DISTRIBUTED_GROUP = None - FP8_PARAMETERS = False - HIGH_PRECISION_INIT_VAL = False - IS_FIRST_FP8_MODULE = False - FP8_GRAPH_CAPTURING = False - SKIP_FP8_REDUCTION_FOR_FSDP2 = False - FP8_AUTOCAST_DEPTH = 0 - global_amax_buffer = {} - global_amax_history_buffer = {} - global_scale_buffer = {} - fp8_tensors_recompute_buffer = [] - fp8_available = None - reason_for_no_fp8 = "" - autocast_arguments = {} - autocast_to_fp8_params = {} - fp8_param_to_autocast = {} - skip_fp8_weight_update_tensor = None - mxfp8_available = None - reason_for_no_mxfp8 = "" - fp8_block_scaling_available = None - reason_for_no_fp8_block_scaling = None - - @classmethod - def reset(cls) -> None: - """Reset the global state""" - cls.FP8_ENABLED = False - cls.FP8_CALIBRATION = False - cls.FP8_RECIPE = None - cls.FP8_DISTRIBUTED_GROUP = None - cls.FP8_PARAMETERS = False - cls.HIGH_PRECISION_INIT_VAL = False - cls.IS_FIRST_FP8_MODULE = False - cls.FP8_GRAPH_CAPTURING = False - cls.FP8_AUTOCAST_DEPTH = 0 - cls.global_amax_buffer = {} - cls.global_amax_history_buffer = {} - cls.global_scale_buffer = {} - cls.fp8_tensors_recompute_buffer = [] - cls.fp8_available = None - cls.reason_for_no_fp8 = "" - cls.autocast_arguments = {} - cls.autocast_to_fp8_params = {} - cls.fp8_param_to_autocast = {} - cls.skip_fp8_weight_update_tensor = None - cls.mxfp8_available = None - cls.reason_for_no_mxfp8 = "" - cls.fp8_block_scaling_available = None - cls.reason_for_no_fp8_block_scaling = "" - - @classmethod - def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: - """`skip_fp8_weight_update_tensor` inplace setter.""" - if cls.skip_fp8_weight_update_tensor is None: - cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") - cls.skip_fp8_weight_update_tensor.fill_(skip) - - @classmethod - def get_skip_fp8_weight_update_tensor(cls) -> None: - """`skip_fp8_weight_update_tensor` getter.""" - return 
cls.skip_fp8_weight_update_tensor - - @classmethod - def is_fp8_available(cls) -> Tuple[bool, str]: - """Return if fp8 support is available""" - if cls.fp8_available is None: - cls.fp8_available, cls.reason_for_no_fp8 = check_fp8_support() - return cls.fp8_available, cls.reason_for_no_fp8 - - @classmethod - def is_mxfp8_available(cls) -> Tuple[bool, str]: - """Return if MXFP8/current scaling support is available.""" - if cls.mxfp8_available is None: - cls.mxfp8_available, cls.reason_for_no_mxfp8 = check_mxfp8_support() - return cls.mxfp8_available, cls.reason_for_no_mxfp8 - - @classmethod - def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: - """Return if Float8 block scaling support is available.""" - if cls.fp8_block_scaling_available is None: - cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling = ( - check_fp8_block_scaling_support() - ) - return cls.fp8_block_scaling_available, cls.reason_for_no_fp8_block_scaling - - @staticmethod - def get_meta_tensor_key(forward: bool = True) -> str: - """Returns scaling key in `fp8_meta`.""" - if forward: - return "scaling_fwd" - return "scaling_bwd" - - @staticmethod - def get_fwd_bwd_key(forward: bool = True) -> str: - """Convert bool `forward` to string.""" - return "forward" if forward else "backward" - - @classmethod - def get_buffer_info(cls) -> str: - """ - Returns a key for `fp8_meta` that stores the module's index - in the global buffers along with autocast information. 
- """ - return "buffer_index_and_autocast_key" - - @classmethod - def get_key_in_buffer( - cls, - forward: bool, - fp8_recipe: Recipe, - fp8_group: dist_group_type, - ) -> str: - """Returns a key into the global FP8 buffers.""" - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - fwd_bwd_key = cls.get_fwd_bwd_key(forward) - return f"{fwd_bwd_key}_{autocast_key}" - - @classmethod - def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: - """Splits buffer key into relevant parts.""" - forward, autocast_key = key.split("_", 1) - forward = forward == "forward" - return forward, autocast_key - - @classmethod - def add_fp8_tensors_to_global_buffer( - cls, - fp8_meta: Dict[str, Any], - ) -> None: - """ - Delayed scaling only. - - The amax reduction process happens completely outside the FP8 modules. - To participate in the reduction, the only role played by a module is - to call this function in order to append it's FP8 tensor into a global - buffer. There are 5 global buffers maintained, one each for amax, amax - history, scale, scale-inverse, and non-weight-mask. Each buffer has - keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix - to indicate the type of FP8 tensor, since the forward and backward - reductions happen separately. - - Note: For CG capture, this method is called from the graphed - wrapper. For non CG case, it's called from within the module. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Every module must call this function exactly once since - # the amax tensors are static. Ensures that compatibility - # with non-graphed modules is maintained. - index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. 
- if index_in_buffer in fp8_meta: - return - - fp8_meta[index_in_buffer] = [] - for forward in (True, False): - fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) - if fp8_meta_tensor_key not in fp8_meta: - # Handles non-parameter FP8 modules, e.g. DPA. - continue - - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) - - if key not in cls.global_amax_buffer: - cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] - cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] - cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] - else: - cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) - cls.global_amax_history_buffer[key].append( - fp8_meta[fp8_meta_tensor_key].amax_history - ) - cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) - fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) - fp8_meta[index_in_buffer].append(key) - - @classmethod - def is_fp8_enabled(cls) -> bool: - """Is FP8 enabled""" - return cls.FP8_ENABLED - - @classmethod - def is_fp8_calibration(cls) -> bool: - """Is FP8 calibration""" - return cls.FP8_CALIBRATION - - @classmethod - def with_fp8_parameters(cls) -> bool: - """Should the parameters be stored as FP8""" - return cls.FP8_PARAMETERS - - @classmethod - def with_high_precision_init_val(cls) -> bool: - """Should the high precision initial values be stored with FP8 parameters""" - return cls.HIGH_PRECISION_INIT_VAL - - @classmethod - def fp8_graph_capturing(cls) -> bool: - """Is CUDA graph capture under way?""" - return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() - - @classmethod - def is_first_fp8_module(cls): - """Returns `True` only the first time when called multiple - times from within the same `fp8_autocast` context. 
- """ - tmp = cls.IS_FIRST_FP8_MODULE - cls.IS_FIRST_FP8_MODULE = False - return tmp - - @classmethod - def get_fp8_recipe(cls) -> Recipe: - """Return the fp8 recipe""" - if cls.FP8_RECIPE is not None: - return cls.FP8_RECIPE - return get_default_fp8_recipe() - - @classmethod - def get_fp8_group(cls) -> Union[dist_group_type, None]: - """Return the fp8 group for scale/amax comm""" - return cls.FP8_DISTRIBUTED_GROUP - - @classmethod - def get_fp8_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: - """FP8 autocast state getter""" - return ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) - - @classmethod - def set_fp8_autocast_state( - cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] - ) -> None: - """FP8 autocast state setter""" - ( - cls.FP8_ENABLED, - cls.FP8_CALIBRATION, - cls.FP8_RECIPE, - cls.FP8_DISTRIBUTED_GROUP, - cls.IS_FIRST_FP8_MODULE, - cls.FP8_GRAPH_CAPTURING, - ) = fp8_state - - @staticmethod - def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: - """Reduce tensor across given group.""" - if torch.distributed.is_initialized(): - torch.distributed.all_reduce( - tensor, - op=torch.distributed.ReduceOp.MAX, - group=group, - async_op=False, - ) - - @classmethod - def reduce_and_update_fp8_tensors( - cls, - forward: bool = True, - ) -> None: - """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" - # global_amax_buffer should only be non-empty for fp8 delayed scaling - for buffer_key, amax_buffer in cls.global_amax_buffer.items(): - # Check for forward or backward reduction. - fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) - if fwd_update != forward: - continue - if len(amax_buffer) == 0: - continue - - # Retrieve autocast specific args and concat amaxes. 
- recipe, group = cls.autocast_arguments[autocast_key] - contiguous_amax = torch.cat(amax_buffer) - - # Reduction. - if ( - recipe.reduce_amax - and torch.distributed.is_initialized() - and torch.distributed.get_world_size(group=group) > 1 - ): - cls.reduce_tensor_across_group_op_max(contiguous_amax, group) - - # Amax and scale update. - unfused_update = ( - bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) - or callable(recipe.amax_compute_algo) - or callable(recipe.scaling_factor_compute_algo) - ) - - if not unfused_update: - tex.fused_amax_and_scale_update_after_reduction( - contiguous_amax, - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - recipe.amax_compute_algo, - get_fp8_te_dtype(recipe, forward), - recipe.margin, - ) - else: - split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) - - for amax_history, scale in zip( - cls.global_amax_history_buffer[buffer_key], - cls.global_scale_buffer[buffer_key], - ): - _amax_and_scale_update( - amax_history, scale, get_fp8_max(recipe, forward), recipe - ) - - @classmethod - def get_unique_autocast_key( - cls, - recipe: Optional[Recipe] = None, - group: Optional[dist_group_type] = None, - ): - """ - For FP8, each autocast can be uniquely identified by the recipe and fp8 group. - Safely using `hash` as we never cross checkpoint boundaries. 
- """ - return f"{str(recipe)}:{hash(group)}" - - @classmethod - def fp8_autocast_enter( - cls, - enabled: bool = False, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, - ) -> None: - """Set state and tracking variables for entry into FP8 region.""" - - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe - autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) - cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) - - cls.FP8_ENABLED = enabled - cls.FP8_CALIBRATION = calibrating - cls.FP8_RECIPE = fp8_recipe - cls.FP8_DISTRIBUTED_GROUP = fp8_group - cls.FP8_GRAPH_CAPTURING = _graph - - if cls.FP8_AUTOCAST_DEPTH == 0: - cls.IS_FIRST_FP8_MODULE = True - cls.FP8_AUTOCAST_DEPTH += 1 - - if enabled: - fp8_available, reason_for_no_fp8 = cls.is_fp8_available() - assert fp8_available, reason_for_no_fp8 - if isinstance(fp8_recipe, MXFP8BlockScaling): - mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() - assert mxfp8_available, reason_for_no_mxfp8 - if isinstance(fp8_recipe, Float8BlockScaling): - fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() - assert fp8_block_available, reason_for_no_fp8_block - - @classmethod - def fp8_autocast_exit(cls, enabled: bool, _graph: bool) -> None: - """Set state and tracking variables for exit from FP8 region.""" - cls.FP8_AUTOCAST_DEPTH -= 1 - # Reduce only the non-FP8 weight modules here. - # FP8 weight modules are reduced at the end of the optimizer - # step after the weight amax is populated. 
- if not cls.SKIP_FP8_REDUCTION_FOR_FSDP2 and enabled and cls.FP8_AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled(): - # delayed scaling only function, for other recipes (current scaling with any granularity), - # this is noop for other recipes because cls.global_amax_buffer is empty list - cls.reduce_and_update_fp8_tensors(forward=True) - - @classmethod - def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Copy the scaling factors and amaxes for recompute forward phase - to ensure both forward steps are numerically same. - """ - - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - - to_copy = [ - fp8_meta["scaling_fwd"].amax_history.clone(), - fp8_meta["scaling_fwd"].scale.clone(), - ] - - if buffer_position_key in fp8_meta: - cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy) - else: - if len(cls.fp8_tensors_recompute_buffer) == 0: - cls.fp8_tensors_recompute_buffer = [deque()] - else: - cls.fp8_tensors_recompute_buffer.append(deque()) - cls.fp8_tensors_recompute_buffer[-1].append(to_copy) - fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1 - - @classmethod - def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None: - """Switch to the copied scaling factors and amaxes from phase - 1 forward for indentical numerical outputs. - """ - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - # Store updated amaxes and scales from phase 1 post forward. - fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone() - fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone() - - # Retrieve stashed amaxes and scales from phase 1 pre forward. 
- buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" - stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() - - # Replace amaxes and scales with stashed values for phase 2 forward - fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) - fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) - - @staticmethod - def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: - """Restore latest scaling factors and amaxes after recompute forward run.""" - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): - return - - fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) - fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) - - -@contextmanager -def fp8_model_init( - enabled: bool = True, - recipe: Optional[Recipe] = None, - preserve_high_precision_init_val: bool = False, -) -> None: - """ - Context manager for FP8 initialization of parameters. - - Example usage: - - .. code-block:: python - with fp8_model_init(enabled=True): - model = transformer_engine.pytorch.Linear(768, 768) - - # Preserving high precision initial value to initialize master weight - with fp8_model_init(enabled=True, preserve_high_precision_init_val=True): - model = transformer_engine.pytorch.Linear(768, 768) - master_weight = model.weight.get_high_precision_init_val() - model.weight.clear_high_precision_init_val() - - Parameters - ---------- - enabled: bool, default = `True` - when enabled, Transformer Engine modules created inside this `fp8_model_init` - region will hold only FP8 copies of its parameters, as opposed to the default - behavior where both higher precision and FP8 copies are present. 
Setting this - option to `True` may result in lower memory consumption and is especially - useful for scenarios like: - - * full model training using optimizer with master weights, where the high - precision copies of weights are already present in the optimizer. - * inference, where only the FP8 copies of the parameters are used. - * LoRA-like fine-tuning, where the main parameters of the model do not change. - recipe: transformer_engine.common.recipe.Recipe, default = `None` - Recipe used to create the parameters. If left to None, it uses the default FP8 recipe. - preserve_high_precision_init_val: bool, default = `False` - when enabled, store the high precision tensor used to initialize FP8 parameters - in CPU memory, and add two function attributes named `get_high_precision_init_val()` - and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high - precision tensor. The purpose is that users can use this high-precision copy - to initialize master weights, avoiding the loss of precision that can occur when - using FP8 parameters directly. Note that after the master weights are initialized, - users should call `clear_high_precision_init_val()` to release this CPU memory. - - This functionality is *EXPERIMENTAL*. 
- """ - _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS - _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE - _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL - FP8GlobalStateManager.FP8_PARAMETERS = enabled - FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val - try: - yield - finally: - FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters - FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe - FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val - - -@contextmanager -def fp8_autocast( - enabled: bool = True, - calibrating: bool = False, - fp8_recipe: Optional[Recipe] = None, - fp8_group: Optional[dist_group_type] = None, - _graph: bool = False, -) -> None: - """ - Context manager for FP8 usage. - - .. code-block:: python - - with fp8_autocast(enabled=True): - out = model(inp) - - .. note:: - - Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors - with shapes where both dimensions are divisible by 16. In terms of the input to the full - Transformer network, this typically requires padding sequence length to be multiple of 16. - - .. note:: - - When :attr:`fp8_recipe.reduce_amax==True`, any module must not be invoked more than once - inside a single `fp8_autocast` region. This is unsupported behavior because the amax - reduction is handled during the exit of the `fp8_autocast` context. Calling the same - module more than once inside an `fp8_autocast` region overrides the amax tensors - before reduction can occur. - - Parameters - ---------- - enabled: bool, default = `True` - whether or not to enable fp8 - calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. 
This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: recipe.Recipe, default = `None` - recipe used for FP8 training. - fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - """ - if enabled: - check_recipe_support(fp8_recipe) - fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() - FP8GlobalStateManager.fp8_autocast_enter( - enabled=enabled, - calibrating=calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, - _graph=_graph, - ) - try: - yield - finally: - FP8GlobalStateManager.set_fp8_autocast_state(fp8_state) - FP8GlobalStateManager.fp8_autocast_exit(enabled, _graph=_graph) - - -def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: - """Update amax history and set next amax to zero.""" - if amax_history.shape[0] > 1: - new_amax_history = torch.roll(amax_history, -1, 0) - amax_history.copy_(new_amax_history) - amax_history[0].fill_(0.0) - return amax_history - - -@torch.jit.script -def _default_get_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: str, -) -> Tuple[torch.Tensor, torch.Tensor]: - """Default function to obtain amax from history.""" - if amax_compute_algo == "max": - amax = torch.max(amax_history, dim=0).values - else: # amax_compute_algo == "most_recent" - amax = amax_history[0].clone() - - amax_history = _update_amax_history(amax_history) - return amax_history, amax - - -@jit_fuser -def _default_sf_compute( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - margin: int, - _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter -) -> torch.Tensor: - """Default function to convert amax to scaling factor. - Computing the scaling factor requires consideration of the following scenarios: - 1. amax == 0: - No action is possible, set scale to the previous scale (or 1). - 2. 
0 < amax < tiny_amax - The amax is too tiny that the scale becomes infinite in FP32. - Set scale = FP32_max - 3. tiny_amax <= amax < FP32_max: - Set scale = FP8_max (or scaled_max) / amax - 4. When amax == inf or amax == nan: - No action is possible, set scale to the previous scale (or 1). - """ - sf = (fp8_max / amax) / (2**margin) - sf = torch.where(amax > 0.0, sf, scale) - sf = torch.where(torch.isfinite(amax), sf, scale) - sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) - scale.copy_(sf) - return scale - - -def _compute_amax_and_update_history( - amax_history: torch.Tensor, - amax_compute_algo: Union[Callable, str], -) -> Tuple[torch.Tensor, torch.Tensor]: - """Obtain the amax from the history.""" - - if callable(amax_compute_algo): - amax = amax_compute_algo(amax_history) - amax_history = _update_amax_history(amax_history) - return amax_history, amax - return _default_get_amax_and_update_history( - amax_history, - amax_compute_algo, - ) - - -def _compute_scaling_factor( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> torch.Tensor: - """Convert amax to scaling factor.""" - - if recipe.scaling_factor_compute_algo is None: - return _default_sf_compute( - amax, - scale, - fp8_max, - recipe.margin, - ) - return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) - - -def _amax_and_scale_update( - amax_history: torch.Tensor, - scale: torch.Tensor, - fp8_max: float, - recipe: DelayedScaling, -) -> None: - """Updates FP8 meta tensors.""" - new_amax_history, amax = _compute_amax_and_update_history( - amax_history, - recipe.amax_compute_algo, - ) - new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) - scale.copy_(new_scale) - amax_history.copy_(new_amax_history) - - -def split_and_copy( - buffer: torch.Tensor, - outputs: List[torch.Tensor], - chunk_sizes: List[int], -) -> None: - """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" - splits = 
buffer.split(chunk_sizes) - torch._foreach_copy_(outputs, splits) - - -class RecipeState(abc.ABC): - """Configuration and state for a quantization recipe. - - This is a builder class for quantizers, which are in turn builder - classes for quantized tensors. - - This class may pack together the state for multiple quantizers, - which is helpful for applying fused kernels with less overhead. - - """ - - @staticmethod - def create( - recipe: Recipe, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> RecipeState: - """Factory method to create the state for a quantization recipe - - Parameters - ---------- - recipe: Recipe - Quantization recipe. - mode: {"forward", "backward"} - Training stage where quantization will be performed. - num_quantizers: int, default = 1 - Number of quantizers to create state for. - device: torch.device, default = default CUDA device - Device for quantized tensors. - - Returns - ------- - RecipeState: - Quantization recipe state. - - """ - - cls = None - if recipe.delayed(): - cls = DelayedScalingRecipeState - elif recipe.mxfp8(): - cls = MXFP8BlockScalingRecipeState - elif recipe.float8_current_scaling(): - cls = Float8CurrentScalingRecipeState - elif recipe.float8_block_scaling(): - cls = Float8BlockScalingRecipeState - else: - raise ValueError(f"{recipe.__class__.__name__} is not supported") - return cls( - recipe, - mode=mode, - num_quantizers=num_quantizers, - device=device, - ) - - @abc.abstractmethod - def make_quantizers(self) -> list: - """Convert recipe state to quantizers. - - Quantizers are builder classes for quantized tensors. They are - typically used to convert a high-precision tensor (e.g. in - FP32 or BF16) into a quantized tensor (e.g. in FP8). - - """ - - -class DelayedScalingRecipeState(RecipeState): - """State for FP8 quantization with per-tensor delayed scaling. 
- - Delayed scaling recipe requires a scaling factor (applied when - casting to FP8) and a history of max-abs values ("amax") from - recent FP8 casts for updating the scaling factor. The scale update - is handled externally by `FP8GlobalStateManager`. - - """ - - recipe: DelayedScaling - mode: str - dtype: tex.DType - scale: torch.Tensor - amax_history: torch.Tensor - - def __init__( - self, - recipe: DelayedScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) - self.amax_history = torch.zeros( - recipe.amax_history_len, - num_quantizers, - dtype=torch.float32, - device=device, - ) - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_tensor import Float8Quantizer - - return [ - Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) - for i in range(self.num_quantizers) - ] - - -class Float8CurrentScalingRecipeState(RecipeState): - """Configuration for Per-tensor current scaling quantization. - - Per-tensor current quantization does not require state. 
- - """ - - recipe: Float8CurrentScaling - mode: str - dtype: tex.DType - device: torch.device - - def __init__( - self, - recipe: Float8CurrentScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - from .tensor.float8_tensor import Float8CurrentScalingQuantizer - - return [ - Float8CurrentScalingQuantizer(self.dtype, device=self.device) - for i in range(self.num_quantizers) - ] - - -class MXFP8BlockScalingRecipeState(RecipeState): - """Configuration for MXFP8 quantization. - - MXFP8 quantization does not require state. - - """ - - recipe: MXFP8BlockScaling - mode: str - dtype: tex.DType - - def __init__( - self, - recipe: MXFP8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.dtype = get_fp8_te_dtype(recipe, mode == "forward") - - # Allocate buffers - if device is None: - device = torch.device("cuda") - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.mxfp8_tensor import MXFP8Quantizer - - return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] - - -class Float8BlockScalingRecipeState(RecipeState): - """Configuration for Float8BlockScaling quantization. - - Float8BlockScaling quantization does not require state, - but different quantizers use different modes. 
- """ - - recipe: Float8BlockScaling - mode: str - qx_dtype: tex.DType - qw_dtype: tex.DType - qgrad_dtype: tex.DType - - def __init__( - self, - recipe: Float8BlockScaling, - *, - mode: str, - num_quantizers: int = 1, - device: Optional[torch.device] = None, - ) -> None: - self.recipe = recipe - self.mode = mode - self.num_quantizers = num_quantizers - self.qx_dtype = get_fp8_te_dtype(recipe, True) - self.qw_dtype = get_fp8_te_dtype(recipe, True) - self.qgrad_dtype = get_fp8_te_dtype(recipe, False) - - # Allocate buffers - if device is None: - device = torch.device("cuda") - self.device = device - - def make_quantizers(self) -> list: - # TODO(ksivamani); Find better design for this, adding here to avoid circular import. - from .tensor.float8_blockwise_tensor import Float8BlockQuantizer - - if self.mode == "forward": - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward, and doesn't play nicely with QuantizeOp, - # which is not associated with a GEMM. - assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qw_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 3) - ] - ) - ) - - assert 
self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 2) - ] - ) - ) +# Importing each function instead of 'import *' allows us specify '__all__' in +# quantize.py and also makes any newer additions to quantize.py invisible via +# fp8.py so that we don't reinforce importing internal TE functions. 
+from .quantization import ( + check_fp8_support, + check_mxfp8_support, + check_nvfp4_support, + check_fp8_block_scaling_support, + check_recipe_support, + get_default_fp8_recipe, + get_fp8_torch_dtype, + get_fp8_te_dtype, + get_fp4_te_dtype, + get_fp8_max, + FP8GlobalStateManager, + fp8_model_init, + fp8_autocast, + _update_amax_history, + _default_get_amax_and_update_history, + _default_sf_compute, + _compute_amax_and_update_history, + _compute_scaling_factor, + _amax_and_scale_update, + split_and_copy, + RecipeState, + DelayedScalingRecipeState, + Float8CurrentScalingRecipeState, + MXFP8BlockScalingRecipeState, + Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, + CustomRecipeState, +) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 4984013ab..f1eba59f3 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -8,6 +8,7 @@ from collections.abc import Iterable import contextlib import gc +import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -17,8 +18,8 @@ from transformer_engine.common.recipe import DelayedScaling, Recipe from transformer_engine.pytorch.constants import dist_group_type -from .fp8 import ( - fp8_autocast, +from .quantization import ( + autocast, FP8GlobalStateManager, get_default_fp8_recipe, ) @@ -86,7 +87,7 @@ def _make_graphed_callables( sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], num_warmup_iters: int = 3, allow_unused_input: bool = False, - fp8_weight_caching: bool = False, + cache_quantized_params: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, _order: Optional[List[int]] = None, _num_layers_per_chunk: Optional[List[int]] = None, @@ -254,7 +255,7 @@ def _make_graphed_callables( consumed_sample_q[sample_keys].append(per_callable_fwd_idx) fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] - if fp8_weight_caching: + if 
cache_quantized_params: # Initialize flag that controls FP8 weight updates FP8GlobalStateManager.set_skip_fp8_weight_update_tensor(False) @@ -323,14 +324,16 @@ def _make_graphed_callables( fwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] bwd_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] + bwd_dw_graphs = [torch.cuda.CUDAGraph() for _ in range(len(flatten_sample_args))] graph_callables = [None for _ in range(len(flatten_sample_args))] # For cases with multiple active RNG states, e.g. TP. if graph_safe_rng_available(): for _, state in get_all_rng_states().items(): - for fwd_graph, bwd_graph in zip(fwd_graphs, bwd_graphs): + for fwd_graph, bwd_graph, bwd_dw_graph in zip(fwd_graphs, bwd_graphs, bwd_dw_graphs): fwd_graph.register_generator_state(state) bwd_graph.register_generator_state(state) + bwd_dw_graph.register_generator_state(state) mempool = graph_pool_handle() if pool is None else pool @@ -367,21 +370,8 @@ def _make_graphed_callables( ), f"Warmup runs {len(warmup_func)} but only {len(set(warmup_func_idx))} are unique." # Filter the TE modules that cudagraph can access. - visited_te_modules = set() - - def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument - if isinstance(module, TransformerEngineBaseModule): - visited_te_modules.add(module) - # If forward is called on a BasicOperation directly the hook will run - elif isinstance(module, BasicOperation): - visited_te_modules.add(module) - # If forward is called on a te.ops.Sequential it is not called on its constituent ops - elif isinstance(module, Sequential): - assert module._module_groups is not None, "Should have been initialized by warmup" - for module_group in module._module_groups: - if isinstance(module_group, OperationFuser): - for basic_op in module_group._basic_ops: - visited_te_modules.add(basic_op) + visited_te_modules = {} + need_bwd_dw_graph = {} # Run warmup and do the above filtering. 
with torch.cuda.stream(torch.cuda.Stream()): @@ -389,6 +379,31 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[func_idx] kwargs = sample_kwargs[func_idx] static_input_surface = per_callable_static_input_surfaces[func_idx] + + def hook_fn( + module, inputs, outputs, func_idx=func_idx + ): # pylint: disable=unused-argument + modules = set() + if isinstance(module, TransformerEngineBaseModule): + modules.add(module) + # If forward is called on a BasicOperation directly the hook will run + elif isinstance(module, BasicOperation): + modules.add(module) + # If forward is called on a te.ops.Sequential it is not called on its constituent ops + elif isinstance(module, Sequential): + assert ( + module._module_groups is not None + ), "Should have been initialized by warmup" + for module_group in module._module_groups: + if isinstance(module_group, OperationFuser): + for basic_op in module_group._basic_ops: + modules.add(basic_op) + if modules: + if func_idx not in visited_te_modules: + visited_te_modules[func_idx] = modules + else: + visited_te_modules[func_idx].update(modules) + for warmup_iter in range(num_warmup_iters): hooks = [] for module in func.modules(): @@ -433,6 +448,15 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument module_params_with_grad ) per_callable_static_input_surfaces[func_idx] = static_input_surface + + # Run wgrad. This is essential for some TE modules when they have + # delay_wgrad_compute enabled. 
+ need_backward_dw = False + for module in visited_te_modules.get(func_idx, set()): + if hasattr(module, "need_backward_dw") and module.need_backward_dw(): + need_backward_dw = True + module.backward_dw() + need_bwd_dw_graph[func_idx] = need_backward_dw else: grad_inputs = None del outputs, grad_inputs @@ -515,6 +539,17 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument allow_unused=allow_unused_input, retain_graph=retain_graph_in_backward, ) + # If no one module needs the backward_dw, the bwd_dw_graph will be empty. + # So skip capturing it. + if need_bwd_dw_graph[per_callable_bwd_idx]: + bwd_dw_graph = bwd_dw_graphs[per_callable_bwd_idx] + with _graph_context_wrapper(bwd_dw_graph, pool=mempool): + for module in visited_te_modules[per_callable_bwd_idx]: + if ( + hasattr(module, "need_backward_dw") + and module.need_backward_dw() + ): + module.backward_dw() # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs # that don't require grad. I couldn't think of a one-liner for this pattern. 
@@ -583,10 +618,12 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument # Capture backward graphs in reverse order per_callable_static_grad_outputs = [] per_callable_static_grad_inputs = [] - for static_input_surface, static_outputs, bwd_graph in zip( + for static_input_surface, static_outputs, bwd_graph, bwd_dw_graph, bwd_idx in zip( reversed(per_callable_static_input_surfaces), reversed(per_callable_static_outputs), reversed(bwd_graphs), + reversed(bwd_dw_graphs), + reversed(range(len(per_callable_static_input_surfaces))), ): # For now, assumes all static_outputs require grad static_grad_outputs = tuple( @@ -602,6 +639,11 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument allow_unused=allow_unused_input, retain_graph=retain_graph_in_backward, ) + if need_bwd_dw_graph[bwd_idx]: + with _graph_context_wrapper(bwd_dw_graph, pool=mempool): + for module in visited_te_modules[bwd_idx]: + if hasattr(module, "need_backward_dw") and module.need_backward_dw(): + module.backward_dw() # Constructs a tuple suitable for returning from Graphed.backward: # Pads out the actually-needed grads with Nones in gradient slots for inputs that # don't require grad. I couldn't think of a slick one-liner for this pattern. @@ -691,7 +733,7 @@ def functionalized(*user_args, **user_kwargs): # Decide whether to update FP8 weights skip_fp8_weight_update = None - if fp8_weight_caching: + if cache_quantized_params: assert "is_first_microbatch" in user_kwargs and isinstance( user_kwargs["is_first_microbatch"], bool ), "`is_first_microbatch` boolean kwarg must be provided for FP8 weight caching." @@ -718,6 +760,21 @@ def functionalized(*user_args, **user_kwargs): return functionalized + def make_graphed_attribute_functions(graph_idx): + + # Attach backward_dw as an attribute to the graphed callable. 
+ def backward_dw(): + if need_bwd_dw_graph.get(graph_idx, False): + bwd_dw_graphs[graph_idx].replay() + + # Attach reset as an attribute to the graphed callable. + def reset(): + fwd_graphs[graph_idx].reset() + bwd_graphs[graph_idx].reset() + bwd_dw_graphs[graph_idx].reset() + + return backward_dw, reset + # Put together the final graphed callables ret = [] for i in range(len(sample_args)): @@ -735,9 +792,10 @@ def functionalized(*user_args, **user_kwargs): ) func = graph_callables[i] + te_modules = visited_te_modules.get(i, set()) if isinstance(func, torch.nn.Module): - def make_graphed_forward(func, graph_training_state, graphed, orig_fwd): + def make_graphed_forward(func, graph_training_state, graphed, orig_fwd, te_modules): def new_fwd(*user_args, **user_kwargs): # If the module's training-or-eval state matches what we graphed, # run the graph, otherwise run the original forward method @@ -746,7 +804,7 @@ def new_fwd(*user_args, **user_kwargs): if FP8GlobalStateManager.is_fp8_enabled(): fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() for m in func.modules(): - if m not in visited_te_modules: + if m not in te_modules: # Only Set the FP8 meta for the modules included by forward continue if isinstance(m, TransformerEngineBaseModule): @@ -783,7 +841,7 @@ def new_fwd(*user_args, **user_kwargs): return new_fwd - forward = make_graphed_forward(func, func.training, graphed, func.forward) + forward = make_graphed_forward(func, func.training, graphed, func.forward, te_modules) if _order is None: func.forward = forward ret.append(func) @@ -792,6 +850,10 @@ def new_fwd(*user_args, **user_kwargs): else: ret.append(graphed) + backward_dw_func, reset_func = make_graphed_attribute_functions(i) + setattr(ret[-1], "backward_dw", backward_dw_func) + setattr(ret[-1], "reset", reset_func) + if just_one_callable: return ret[0] @@ -800,14 +862,14 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: Optional[Recipe], + 
recipe: Optional[Recipe], ) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules with adjusted amax history sizes. """ - if not isinstance(fp8_recipe, DelayedScaling): + if not isinstance(recipe, DelayedScaling): return None fp8_tensors = [] @@ -816,10 +878,10 @@ def save_fp8_tensors( module_tensors = None if isinstance(m, TransformerEngineBaseModule): if m.primary_weights_in_fp8: - m.adjust_amax_history_length(fp8_recipe.amax_history_len) + m.adjust_amax_history_length(recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() elif isinstance(m, BasicOperation): - m.reset_recipe_state(recipe=fp8_recipe) + m.reset_recipe_state(recipe=recipe) module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors @@ -854,11 +916,16 @@ def make_graphed_callables( num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, - fp8_enabled: SingleOrTuple[bool] = False, - fp8_calibrating: bool = False, + fp8_enabled: Optional[SingleOrTuple[bool]] = None, + fp8_calibrating: Optional[bool] = None, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, - fp8_weight_caching: bool = False, + fp8_weight_caching: Optional[bool] = None, + enabled: Optional[SingleOrTuple[bool]] = None, + calibrating: Optional[bool] = None, + recipe: Optional[Recipe] = None, + amax_reduction_group: Optional[dist_group_type] = None, + cache_quantized_params: Optional[bool] = None, _order: Optional[List[int]] = None, _num_layers_per_chunk: Optional[List[int]] = None, pool: Optional[Tuple[int, ...]] = None, @@ -874,6 +941,11 @@ def make_graphed_callables( `original PyTorch implementation `_ for more documentation. + .. warning:: + + Arguments 'fp8_enabled', 'fp8_calibrating', 'fp8_recipe', 'fp8_group', and 'fp8_weight_caching' are deprecated. + Use arguments 'enabled', 'calibrating', 'recipe', 'amax_reduction_group', and 'cache_quantized_params' instead. 
+ Graphing parameters ------------------- modules: (tuple of) callable @@ -898,30 +970,110 @@ def make_graphed_callables( when `_order` is provided. All callables in `modules` are assumed to have inputs and outputs with the same dtype and shape. - FP8-related parameters + Quantization related parameters ---------------------- - fp8_enabled: (tuple of) bool, default = `False` - whether or not to enable fp8. - If tuple, the length must match the number of modules. - fp8_calibrating: bool, default = `False` - calibration mode allows collecting statistics such as amax and scale - data of fp8 tensors even when executing without fp8 enabled. This is - useful for saving an inference ready fp8 checkpoint while training - using a higher precision. - fp8_recipe: Recipe, default = `None` - recipe used for FP8 training. - fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` - distributed group over which amaxes for the fp8 tensors - are reduced at the end of each training step. - fp8_weight_caching: bool, default = `False` - Whether or not to cache FP8 weights across microbatches. if set to `True`, - the `is_first_microbatch` boolean argument must be passed into the forward - method for TransformerEngine modules. When storing primary weights in FP8 - using TE's `fp8_model_init` API and using an FP8 aware optimizer, this arg - must be set to `False` if calculating weight transposes' outside TE, e.g., - in the optimizer step. + enabled: (tuple of) bool, default = `False` + whether or not to enable low precision quantization (FP8/FP4). + If tuple, the length must match the number of modules. + calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of quantized tensors even when executing without quantization enabled. + This is useful for saving an inference ready checkpoint while training + using a higher precision. + recipe: recipe.Recipe, default = `None` + recipe used for low precision quantization. 
+ amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the quantized tensors + are reduced at the end of each training step. + cache_quantized_params: bool, default = `False` + Whether or not to cache quantized weights across microbatches. if set to `True`, + the `is_first_microbatch` boolean argument must be passed into the forward + method for TransformerEngine modules. When storing primary weights in low precision + using TE's `quantized_model_init` API and using an quantization aware optimizer, + this arg must be set to `False` if calculating weight transposes' outside TE, e.g., + in the optimizer step. """ + + # Handle deprecated args. If old kwargs are set, they are prioritized with warning. + if fp8_enabled is not None: + if enabled is not None: + raise ValueError( + "make_graphed_callables has deprecated `fp8_enabled` kwarg " + "in favor of `enabled`, but both kwargs are set." + ) + warnings.warn( + "make_graphed_callables has deprecated `fp8_enabled` kwarg in favor of `enabled`. " + "`fp8_enabled` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + enabled = fp8_enabled + if enabled is None: + enabled = False + + if fp8_calibrating is not None: + if calibrating is not None: + raise ValueError( + "make_graphed_callables has deprecated `fp8_calibrating` kwarg " + "in favor of `calibrating`, but both kwargs are set." + ) + warnings.warn( + "make_graphed_callables has deprecated `fp8_calibrating` kwarg in favor of " + "`calibrating`. `fp8_calibrating` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + calibrating = fp8_calibrating + if calibrating is None: + calibrating = False + + if fp8_recipe is not None: + if recipe is None: + warnings.warn( + "make_graphed_callables has deprecated `fp8_recipe` kwarg in favor of " + "`recipe`. 
`fp8_recipe` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + raise ValueError( + "make_graphed_callables has deprecated `fp8_recipe` kwarg " + "in favor of `recipe`, but both kwargs are set." + ) + recipe = fp8_recipe + + if fp8_group is not None: + if amax_reduction_group is None: + warnings.warn( + "make_graphed_callables has deprecated `fp8_group` kwarg in favor of " + "`amax_reduction_group`. `fp8_group` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + else: + raise ValueError( + "make_graphed_callables has deprecated `fp8_group` kwarg " + "in favor of `amax_reduction_group`, but both kwargs are set." + ) + amax_reduction_group = fp8_group + + if fp8_weight_caching is not None: + if cache_quantized_params is not None: + raise ValueError( + "make_graphed_callables has deprecated `fp8_weight_caching` kwarg " + "in favor of `cache_quantized_params`, but both kwargs are set." + ) + warnings.warn( + "make_graphed_callables has deprecated `fp8_weight_caching` kwarg in favor of " + "`cache_quantized_params`. `fp8_weight_caching` will be removed in a future release.", + category=DeprecationWarning, + stacklevel=2, + ) + cache_quantized_params = fp8_weight_caching + if cache_quantized_params is None: + cache_quantized_params = False + set_capture_start() # Handle single module. 
@@ -930,21 +1082,21 @@ def make_graphed_callables( just_one_callable = True modules = (modules,) - if not isinstance(fp8_enabled, tuple): - assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools" - fp8_enabled = (fp8_enabled,) * len(modules) + if not isinstance(enabled, tuple): + assert isinstance(enabled, bool), "enabled must be a bool or a tuple of bools" + enabled = (enabled,) * len(modules) else: - assert len(fp8_enabled) == len( + assert len(enabled) == len( modules - ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})" - if any(fp8_enabled) and fp8_recipe is None: - fp8_recipe = get_default_fp8_recipe() - elif not any(fp8_enabled): - fp8_recipe = None - module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled)) + ), f"enabled length ({len(enabled)}) must match modules length ({len(modules)})" + if any(enabled) and recipe is None: + recipe = get_default_fp8_recipe() + elif not any(enabled): + recipe = None + module_uses_fp8 = dict(zip((id(m) for m in modules), enabled)) # Store FP8 tensors to reset later. - saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) + saved_fp8_tensors = save_fp8_tensors(modules, recipe=recipe) # FP8 wrapper. old_call_funcs = {} @@ -958,11 +1110,11 @@ def wrap_autocast(block): # Wrap the original call function of the module class. 
def call_func(self, *args, **kwargs): - with fp8_autocast( + with autocast( enabled=module_uses_fp8.get(id(self), False), - calibrating=fp8_calibrating, - fp8_recipe=fp8_recipe, - fp8_group=fp8_group, + calibrating=calibrating, + recipe=recipe, + amax_reduction_group=amax_reduction_group, _graph=True, ): outputs = old_call_funcs[block_cls](self, *args, **kwargs) @@ -996,7 +1148,7 @@ def call_func(self, *args, **kwargs): sample_args, num_warmup_iters=num_warmup_iters, allow_unused_input=allow_unused_input, - fp8_weight_caching=fp8_weight_caching, + cache_quantized_params=cache_quantized_params, sample_kwargs=sample_kwargs, _order=_order, _num_layers_per_chunk=_num_layers_per_chunk, diff --git a/transformer_engine/pytorch/module/_common.py b/transformer_engine/pytorch/module/_common.py index 1f38b493c..4ba5da68d 100644 --- a/transformer_engine/pytorch/module/_common.py +++ b/transformer_engine/pytorch/module/_common.py @@ -1,27 +1,28 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. """Internal function used by multiple modules.""" -import os -from typing import Any, List, Optional, Tuple, Union, Callable -from dataclasses import dataclass +import dataclasses import queue +from typing import Any, Callable, List, Optional, Tuple, Union + import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION from .. 
import cpp_extensions as tex from ..constants import TE_DType -from ..utils import get_default_init_method from ..export import is_in_onnx_export_mode +from ..utils import get_default_init_method if IS_HIP_EXTENSION: from ..triton_kernels.layernorm import te_layernorm_fwd_triton, te_layernorm_bwd_triton from ..triton_kernels.rmsnorm import te_rmsnorm_bwd_triton, te_rmsnorm_fwd_triton + import os def _get_normalization_func(normalization: str, forward: bool): use_rmsnorm_triton = bool( int(os.environ.get('NVTE_USE_RMSNORM_TRITON', '0')) ) and IS_HIP_EXTENSION @@ -188,7 +189,7 @@ def noop_cat( return _NoopCatFunc.apply(dim, *tensors) -@dataclass +@dataclasses.dataclass class _ParameterInitMeta: """ Stores essential metadata needed to support deferred parameter initialization. diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b49e38544..634c188ce 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -19,17 +19,19 @@ import torch import torch.nn.functional as F +from torch.distributed.tensor import DTensor from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe from ._common import _ParameterInitMeta, noop_cat -from ..fp8 import ( +from ..quantization import ( MXFP8BlockScalingRecipeState, DelayedScalingRecipeState, Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, + NVFP4BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, ) @@ -40,18 +42,17 @@ _fsdp_gather_tensors, ) from ..constants import dist_group_type -from ..tensor.quantized_tensor import QuantizedTensor, QuantizedTensorBase, Quantizer +from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer if IS_HIP_EXTENSION: from ..tensor.fsdp2_allgather_tensor import FSDPAGTensor -from ..tensor._internal.float8_tensor_base import Float8TensorBase -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton +from ..tensor.storage.float8_tensor_storage import Float8TensorStorage +from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage from ..utils import get_device_compute_capability, is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor @@ -89,7 +90,8 @@ def get_cublas_workspace_size_bytes() -> None: return 33_554_432 
"""Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: - return 33_554_432 + # 32 MiB for NVFP4 GEMM, plus additional 1024 B for alignment and misc scales + return 32 * 1024 * 1024 + 1024 return 4_194_304 @@ -515,7 +517,7 @@ def fill_userbuffers_buffer_for_all_gather( local_tensor: torch.Tensor, quantizer: Optional[Quantizer], process_group, -) -> tuple[torch.Tensor | QuantizedTensorBase, torch.Tensor | QuantizedTensorBase]: +) -> tuple[torch.Tensor | QuantizedTensorStorage, torch.Tensor | QuantizedTensorStorage]: """Fill local shard of Userbuffers buffer with data for all-gather Returns the full tensor and the local shard, both using the @@ -539,7 +541,7 @@ def fill_userbuffers_buffer_for_all_gather( # Unquantized data if quantizer is None: - if isinstance(local_tensor, QuantizedTensorBase): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor = local_tensor.dequantize() if comm.is_fp8_ubuf(): raise RuntimeError( @@ -552,8 +554,8 @@ def fill_userbuffers_buffer_for_all_gather( # FP8 data if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): - if not isinstance(local_tensor, Float8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): + if not isinstance(local_tensor, Float8TensorStorage): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor.dequantize() quantizer.set_usage(rowwise=True, columnwise=False) local_tensor = quantizer(local_tensor) @@ -564,7 +566,7 @@ def fill_userbuffers_buffer_for_all_gather( ) comm.copy_into_buffer(local_tensor._data, local_chunk=True) global_tensor_data = comm.get_buffer(shape=global_shape) - global_tensor = Float8TensorBase( + global_tensor = Float8TensorStorage( data=global_tensor_data, fp8_scale_inv=local_tensor._scale_inv, fp8_dtype=local_tensor._fp8_dtype, @@ -576,8 +578,8 @@ def fill_userbuffers_buffer_for_all_gather( if isinstance(quantizer, MXFP8Quantizer): # Cast to MXFP8 
if needed - if not isinstance(local_tensor, MXFP8TensorBase): - if isinstance(local_tensor, QuantizedTensorBase): + if not isinstance(local_tensor, MXFP8TensorStorage): + if isinstance(local_tensor, QuantizedTensorStorage): local_tensor.dequantize() local_tensor = quantizer(local_tensor) if not comm.is_fp8_ubuf(): @@ -632,7 +634,7 @@ def fill_userbuffers_buffer_for_all_gather( rowwise_data, rowwise_scale_inv = global_data, global_scale_inv else: columnwise_data, columnwise_scale_inv = global_data, global_scale_inv - global_tensor = MXFP8TensorBase( + global_tensor = MXFP8TensorStorage( rowwise_data=rowwise_data, rowwise_scale_inv=rowwise_scale_inv, columnwise_data=columnwise_data, @@ -675,6 +677,7 @@ def __init__(self) -> None: self.keep_fp8_weight_transpose_cache: bool = True self.use_fsdp2 = False self.wgrad_accumulation_and_reduce_hooks = [] + self.wgrad_store = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -772,6 +775,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: recipe_state, Float8BlockScalingRecipeState ): return + if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): + return # Max. 
number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd @@ -796,10 +801,10 @@ def _update_weight_quantizers(self) -> None: f"({len(weight_quantizers)}) must match" ) for weight, quantizer in zip(weight_tensors, weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): + if quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_quantizer(quantizer) - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement _get_weight_tensors function" @@ -981,12 +986,13 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: return dtype = inp.dtype - for name, param in self.named_parameters(): - if param is not None: - assert dtype == param.dtype, ( - "Data types for parameters must match when outside of autocasted region. " - f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" - ) + if not self.allow_different_data_and_param_types: + for name, param in self.named_parameters(): + if param is not None: + assert dtype == param.dtype, ( + "Data types for parameters must match when outside of autocasted region. 
" + f" Found input dtype: {dtype} and {name!r} dtype: {param.dtype}" + ) self.activation_dtype = dtype def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: @@ -1050,8 +1056,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() # Set FP8_MAX per tensor according to recipe - self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd - self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd + if hasattr(self.fp8_meta["recipe"], "fp8_format"): + self.fp8_meta["fp8_max_fwd"] = self.fp8_meta["recipe"].fp8_format.value.max_fwd + self.fp8_meta["fp8_max_bwd"] = self.fp8_meta["recipe"].fp8_format.value.max_bwd # Allocate scales and amaxes self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) @@ -1080,6 +1087,7 @@ def prepare_forward( inp: torch.Tensor, num_gemms: int = 1, allow_non_contiguous: bool = False, + allow_different_data_and_param_types: bool = False, ) -> Generator[torch.Tensor, None, None]: """Checks and prep for FWD. The context manager is needed because there isn't a way for a module to know @@ -1087,6 +1095,7 @@ def prepare_forward( to setup the forward aggregated amax reduction for every module just in case. The autocast exit will pick up the most recent one. """ + self.allow_different_data_and_param_types = allow_different_data_and_param_types self.forwarded_at_least_once = True # Activation recomputation is used and this is the second forward phase. 
if self.fp8 and in_fp8_activation_recompute_phase(): @@ -1182,9 +1191,9 @@ def grad_output_preprocess( grad_output, ( QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, ), ): grad_output = quantizer(grad_output) @@ -1213,9 +1222,9 @@ def grad_output_preprocess( grad_output_.get_tensor(True), ( QuantizedTensor, - Float8TensorBase, - MXFP8TensorBase, - Float8BlockwiseQTensorBase, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, ), ) and ctx.use_bias @@ -1231,7 +1240,12 @@ def grad_output_preprocess( if ctx.use_bias: if isinstance( grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), + ( + QuantizedTensor, + Float8TensorStorage, + MXFP8TensorStorage, + Float8BlockwiseQTensorStorage, + ), ): grad_bias = grad_output.dequantize().view(-1, grad_output.shape[-1]).sum(dim=0) else: @@ -1240,10 +1254,7 @@ def grad_output_preprocess( grad_bias = grad_output.view(-1, grad_output.shape[-1]).sum(dim=0) else: grad_bias, grad_output = tex.bgrad_quantize(grad_output, quantizer) - if not isinstance( - grad_output, - (QuantizedTensor, Float8TensorBase, MXFP8TensorBase, Float8BlockwiseQTensorBase), - ): + if not isinstance(grad_output, QuantizedTensorStorage): grad_output = quantizer(grad_output) return grad_output, grad_bias @@ -1253,7 +1264,12 @@ def register_parameter(self, name, param, **kwargs): metedata used in deferred initialization. """ super().register_parameter(name, param) - self.param_init_meta[name] = _ParameterInitMeta(**kwargs) + # Initialize param_init_meta exactly once during the init. FSDP2 can call + # register parameter again to change parameters to DTensors. And it calls + # it without custom fp8 specific kwargs that we need. And so we dont want + # to reset/loose our fp8 init attributes. 
+ if hasattr(self, "param_init_meta") and name not in self.param_init_meta: + self.param_init_meta[name] = _ParameterInitMeta(**kwargs) def reset_parameters(self, defer_init: Optional[bool] = False) -> None: """ @@ -1265,10 +1281,14 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: return for name, param in self.named_parameters(recurse=False): + # Check if parameter is a DTensor (FSDP2) or regular tensor + is_dtensor = isinstance(param, DTensor) + dtensor_param = param if is_dtensor else None + # Need to update/quantize local tensor in case of DTensor + param = param._local_tensor if is_dtensor else param # Ensure parameter is on a real device if param.device == torch.device("meta"): param = torch.empty_like(param, device="cuda") - # Initialize the parameter values on device init_fn = self.param_init_meta[name].init_fn get_rng_state_tracker = self.param_init_meta[name].get_rng_state_tracker @@ -1300,6 +1320,15 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: if IS_HIP_EXTENSION and not self.keep_fp8_weight_transpose_cache: quantizer.columnwise_usage=False + if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer): + device_mesh = dtensor_param.device_mesh + amax_reduction_group = ( + device_mesh.get_group(mesh_dim="shard") + if device_mesh.ndim > 1 + else device_mesh.get_group() + ) + quantizer.amax_reduction_group = amax_reduction_group + quantizer.with_amax_reduction = True # Quantize parameter param = quantizer(param) if IS_HIP_EXTENSION and self.use_fsdp2 and not self.primary_weights_in_fp8 and fp8_meta_index is not None: @@ -1315,7 +1344,18 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. 
- param = torch.nn.Parameter(param) + if is_dtensor: + # recreate the DTensor from the parameter. + dtensor_param = DTensor.from_local( + param, + device_mesh=dtensor_param.device_mesh, + placements=dtensor_param.placements, + shape=dtensor_param.size(), + stride=dtensor_param.stride(), + ) + dtensor_param = torch.nn.Parameter(dtensor_param) + else: + param = torch.nn.Parameter(param) # Keep high-precision values on CPU if needed if high_precision_init_val is not None: @@ -1343,8 +1383,12 @@ def clear(self): param._high_precision_init_val = high_precision_init_val param.get_high_precision_init_val = MethodType(get, param) param.clear_high_precision_init_val = MethodType(clear, param) + # Update the parameter based on its type - setattr(self, name, param) + if not is_dtensor: + setattr(self, name, param) + else: + setattr(self, name, dtensor_param) @abstractmethod def forward(self): @@ -1407,14 +1451,14 @@ def get_weight_workspace( # Reset cache if workspace is invalid if out is not None and quantizer is not None: reset_cache = False - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): if ( not is_non_tn_fp8_gemm_supported() and quantizer.columnwise_usage and out._transpose is None ): reset_cache = True - elif isinstance(out, MXFP8TensorBase): + elif isinstance(out, MXFP8TensorStorage): if quantizer.rowwise_usage and out._rowwise_data is None: reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: @@ -1507,12 +1551,21 @@ def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_re """ self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) + def need_backward_dw(self): + """ + Check if this module needs to execute the delayed weight gradient computation. + This method should be used at the beginning of self.backward_dw() to determine if it + should actually be executed or just return without doing anything. 
+ User can also manually call this method to check that before calling into backward_dw(). + """ + return self.wgrad_store is not None and self.wgrad_store.delay_wgrad_compute() + def backward_dw(self): """ Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + if not self.need_backward_dw(): return with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): (wgrad, bgrad), _ = self.wgrad_store.pop() @@ -1549,7 +1602,13 @@ def is_debug_iter(self) -> bool: debug = False else: debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run - self.debug_last_iteration = TEDebugState.get_iteration() + self.debug_last_iteration = TEDebugState.get_iteration() + self.debug_enabled_in_this_iteration = debug + else: + # If this is the same iteration as previous invocation of the module, + # we use the debug value from the first invocation in the iteration. + debug = self.debug_enabled_in_this_iteration + return debug def no_debug_features_active(self, quantizers): @@ -1600,8 +1659,8 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: - MXFP8BlockScaling → MXFP8Tensor - Float8BlockScaling → Float8BlockTensor - Example case to check: recipe is DelayedScaling (DelayedScaling is set in fp8_autocast()), - but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in fp8_model_init()). + Example case to check: recipe is DelayedScaling (DelayedScaling is set in autocast()), + but the weight tensor is MXFP8Tensor (MXFP8BlockScaling is set in quantized_model_init()). 
""" if not self.fp8 and not self.fp8_calibration: return @@ -1611,7 +1670,7 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: recipe = self.fp8_meta["recipe"] weight_tensors = [getattr(self, name) for name in self.weight_names] for i, tensor in enumerate(weight_tensors): - if isinstance(tensor, QuantizedTensorBase): + if isinstance(tensor, QuantizedTensorStorage): quantizer = tensor._get_quantizer() if quantizer is None: continue @@ -1622,6 +1681,6 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: raise RuntimeError( f"Recipe mismatch for '{self.weight_names[i]}': tensor supports recipe" f" {compatible_recipe_class.__name__}, but got {recipe.__class__.__name__}." - " Please check the recipes assigned during fp8_model_init() and" - " fp8_autocast() calls." + " Please check the recipes assigned during quantized_model_init() and" + " autocast() calls." ) diff --git a/transformer_engine/pytorch/module/fp8_padding.py b/transformer_engine/pytorch/module/fp8_padding.py index 4c6e5f979..415494032 100644 --- a/transformer_engine/pytorch/module/fp8_padding.py +++ b/transformer_engine/pytorch/module/fp8_padding.py @@ -12,7 +12,7 @@ import transformer_engine_torch as tex -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..jit import no_torch_dynamo @@ -80,7 +80,7 @@ class Fp8Padding(torch.nn.Module): number of GEMMs to be performed simultaneously. align_size : int, optional the alignment size for the input tensor. If not provided, the alignment size will - be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first + be determined by the FP8/FP4 recipe (32 for MXFP8/NVFP4 and 16 for others) in the first forward pass. TODO: invesitgate the alignment requirement for non-mxfp8 cases on ROCm @@ -115,7 +115,14 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." 
if self.align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + self.align_size = ( + 32 + if ( + FP8GlobalStateManager.get_fp8_recipe().mxfp8() + or FP8GlobalStateManager.get_fp8_recipe().nvfp4() + ) + else 16 + ) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/fp8_unpadding.py b/transformer_engine/pytorch/module/fp8_unpadding.py index 3b0f8928f..7a01f1572 100644 --- a/transformer_engine/pytorch/module/fp8_unpadding.py +++ b/transformer_engine/pytorch/module/fp8_unpadding.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..jit import no_torch_dynamo @@ -75,9 +75,9 @@ class Fp8Unpadding(torch.nn.Module): num_gemms : int number of GEMMs to be performed simultaneously. align_size : int, optional - the alignment size for the input tensor. If not provided, the alignment size will - be determined by the FP8 recipe (32 for MXFP8 and 16 for others) in the first - forward pass. + The alignment size for the input tensor. If not provided, the alignment size will + be automatically determined based on the FP8/FP4 recipe in the first forward pass: + 32 for MXFP8 or NVFP4, otherwise 16. """ def __init__( @@ -109,7 +109,14 @@ def forward( assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." 
if self.align_size is None: - self.align_size = 32 if FP8GlobalStateManager.get_fp8_recipe().mxfp8() else 16 + self.align_size = ( + 32 + if ( + FP8GlobalStateManager.get_fp8_recipe().mxfp8() + or FP8GlobalStateManager.get_fp8_recipe().nvfp4() + ) + else 16 + ) # FP8 padding calculate padded_m_splits = [ diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 5749d96c9..1a56a06da 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -13,6 +13,7 @@ from transformer_engine.common.recipe import Recipe from .base import ( + get_dummy_wgrad, get_multi_stream_cublas_workspace, TransformerEngineBaseModule, _2X_ACC_FPROP, @@ -20,7 +21,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..utils import ( divide, cast_if_needed, @@ -40,11 +41,11 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..cpu_offload import is_cpu_offload_enabled +from ..cpu_offload import is_cpu_offload_enabled, mark_not_offload, start_offload from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer -from ..tensor.quantized_tensor import ( - QuantizedTensorBase, +from ..quantized_tensor import ( + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -107,9 +108,15 @@ def forward( is_fp8_activation_recompute_enabled() and not in_fp8_activation_recompute_phase() ) - if weight_quantizers[0] is not None: + # No need to set the quantizer states if weight is already quantized + if weight_quantizers[0] is not None and not isinstance( + weights[0], QuantizedTensorStorage + ): for weight_quantizer in weight_quantizers: weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + elif isinstance(weights[0], 
QuantizedTensorStorage): + # If weights are already quantized, no need to set quantizer states + weight_quantizers = [weight._quantizer for weight in weights] if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) @@ -128,6 +135,9 @@ def forward( else: inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + if cpu_offloading: + start_offload(*inputmats) + # Initialize weights weights_fp8: list if fp8: @@ -189,6 +199,9 @@ def forward( for i in range(num_gemms): weight_quantizers[i].calibrate(weights[i]) + if cpu_offloading: + mark_not_offload(*weights_fp8, *weights) + if is_grad_enabled: ctx.weight_quantizers = weight_quantizers ctx.weights_shape_1 = weights[0].shape[1] @@ -200,14 +213,23 @@ def forward( inputmats[0] = inp else: for inputmat in inputmats: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) else: inputmats = [None] * num_gemms - if inp.requires_grad: - for weight in weights_fp8: - if isinstance(weight, QuantizedTensorBase): - weight.update_usage(columnwise_usage=True) + + if cpu_offloading: + ctx.grad_added_to_main_grad = hasattr(weights[0], "grad_added_to_main_grad") + + if ctx.grad_added_to_main_grad: + # If you are passing torch.nn.Parameter through the Torch hooks, you will + # get back torch.Tensor. Torch rips off the Parameter wrapper. + # You need to preserve the weight object to have all the attributes user + # sets for the weights. 
Because of this, it is not recommended to offload + # weights if weights are externally touched outside this module + ctx.weight_objects = [] + for weight in weights: + ctx.weight_objects.append(weight) tensors_to_save, tensor_objects = prepare_for_saving( *inputmats, @@ -271,11 +293,15 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], biases = saved_tensors[3 * N : 4 * N] main_grads = [main_grad_func() for main_grad_func in ctx.main_grad_funcs] - if ctx.cpu_offloading and ctx.fuse_wgrad_accumulation: - for i in range(ctx.num_gemms): - w = torch.nn.Parameter(weights[i], weights[i].requires_grad) - w.main_grad = main_grads[i] - weights[i] = w + if ctx.cpu_offloading: + if ctx.grad_added_to_main_grad: + for i, weight in enumerate(ctx.weight_objects): + origin_weights[i] = ctx.weight_objects[i] + ctx.weight_objects[i] = None + + if ctx.fuse_wgrad_accumulation: + for i in range(N): + origin_weights[i].main_grad = main_grads[i] # Preprocess grad output grad_output_view = grad_output.contiguous().view(-1, grad_output.shape[-1]) @@ -336,13 +362,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dtype=ctx.activation_dtype, device=ctx.device, ) - - for weight, quantizer in zip(weights, ctx.weight_quantizers): - if quantizer is not None and isinstance(weight, QuantizedTensorBase): - weight.update_usage( - rowwise_usage=quantizer.rowwise_usage, - columnwise_usage=quantizer.columnwise_usage, - ) + # Make sure weights are available in column-wise format + # for dgrad computation. 
+ for weight in weights: + if isinstance(weight, QuantizedTensorStorage): + weight.update_usage(columnwise_usage=True) general_grouped_gemm( weights, grad_output, @@ -402,7 +426,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], use_bias=ctx.use_bias if grad_biases[0] is None else None, bias=biases, use_split_accumulator=wgrad_gemm_use_split_accumulator, - accumulate=accumulate_wgrad_into_param_main_grad, + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), ) # WGRAD if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): @@ -426,18 +454,15 @@ def handle_custom_ddp_from_mcore(weight, wgrad): ): weight.grad_added_to_main_grad = True if getattr(weight, "zero_out_wgrad", False): - wgrad = torch.zeros( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, + zero=True, ) else: - wgrad = torch.empty( - weight.main_grad.shape, - dtype=weight.dtype, - device=torch.cuda.current_device(), - requires_grad=False, + wgrad = get_dummy_wgrad( + list(weight.main_grad.shape), + weight.dtype, ) elif ctx.fuse_wgrad_accumulation: wgrad = None @@ -519,7 +544,9 @@ class GroupedLinear(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. 
return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the @@ -734,7 +761,7 @@ def forward( produced) """ assert not isinstance( - inp, QuantizedTensorBase + inp, QuantizedTensorStorage ), "GroupedLinear doesn't support input tensor in FP8." assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." @@ -817,7 +844,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. """ - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + if not self.need_backward_dw(): return with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() @@ -868,16 +895,17 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe self._offsets["input"] + i * self._num_fp8_tensors_per_gemm["bwd"] ].amax_epsilon = recipe.fp8_quant_bwd_grad.amax_epsilon - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] - if not self.fp8 and any(isinstance(w, QuantizedTensorBase) for w in weight_tensors): + if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " "Please make sure this is intentional." 
) weight_tensors = [ - w.dequantize() if isinstance(w, QuantizedTensorBase) else w for w in weight_tensors + w.dequantize() if isinstance(w, QuantizedTensorStorage) else w + for w in weight_tensors ] return weight_tensors diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index d1aeebc9f..89af05f93 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -20,6 +20,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_custom from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -30,9 +31,10 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..utils import ( assert_dim_for_fp8_exec, + assert_dim_for_all_gather, cast_if_needed, clear_tensor_data, divide, @@ -58,9 +60,9 @@ from ..jit import no_torch_dynamo from ..graph import is_graph_capturing from ._common import apply_normalization, noop_cat, WeightGradStore -from ..tensor.quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensor, - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -69,10 +71,15 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase +from ..cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_not_offload, + mark_activation_offload, +) +from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.mxfp8_tensor_storage import 
MXFP8TensorStorage from ..export import is_in_onnx_export_mode, assert_warmed_up -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..cpp_extensions import ( general_gemm, @@ -147,6 +154,8 @@ def forward( if ub_name is not None: nvtx_label = f"{nvtx_label}.{ub_name}" + with_input_all_gather = parallel_mode == "column" and sequence_parallel + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -156,6 +165,7 @@ def forward( inputmat = inp if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer) # Cast for native AMP nvtx_range_push(f"{nvtx_label}.norm_input_cast") @@ -165,11 +175,13 @@ def forward( ln_bias = cast_if_needed(ln_bias, activation_dtype) nvtx_range_pop(f"{nvtx_label}.norm_input_cast") + if is_cpu_offload_enabled(): + start_offload(inputmat) + tp_world_size = get_distributed_world_size(tp_group) weight_requires_grad = weight.requires_grad backward_needs_input = is_grad_enabled and weight_requires_grad - with_input_all_gather = parallel_mode == "column" and sequence_parallel # Configure Userbuffers communication (comm+GEMM overlap) if debug: # turn off userbuffers in debug mode @@ -202,11 +214,13 @@ def forward( # Avoid quantized norm kernel if norm output will be returned # or if a gather of ln_out must be in high precision. 
+ custom = is_custom(input_quantizer) with_quantized_norm = ( fp8 and not debug and not return_layernorm_output and not return_layernorm_output_gathered + and not custom # TODO(negvet): and not FP8GlobalStateManager.get_fp8_recipe().custom() ) # ROCm does not currently support quantized norm for Float8CurrentScalingQuantizer @@ -256,7 +270,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = input_quantizer - if not with_quantized_norm: + # custom recipe doesn't need to support quantized AG + if not with_quantized_norm and not custom: ln_out = quantizer(ln_out) quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag_fprop: # Initialize Userbuffers all-gather @@ -285,12 +300,15 @@ def forward( # Prepare weight tensor # ------------------------------------------------------ weightmat = weight - quantized_weight = False + is_weight_param_quantized = False if fp8 or debug: - quantized_weight = not isinstance(weight, QuantizedTensorBase) + is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) # Configure quantizer - if weight_quantizer is not None: + # If weight is already quantized, no need to set quantizer states + if is_weight_param_quantized: + weight_quantizer = weight._quantizer + elif weight_quantizer is not None: weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) # Get quantized weight @@ -414,19 +432,19 @@ def forward( # Input with column-wise usage is needed for wgrad GEMM. if backward_needs_input: - if isinstance(ln_out, QuantizedTensorBase): + if isinstance(ln_out, QuantizedTensorStorage): # For sequence parallel in vanilla FP8, rowwise data is # to gather the input. For MXFP8, columnwise only data # can be allgathered. 
if ( - isinstance(ln_out, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) + isinstance(ln_out, (MXFP8TensorStorage, Float8BlockwiseQTensorStorage)) or not ctx.ln_out_needs_gather ): ln_out.update_usage(rowwise_usage=False) # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading: @@ -441,14 +459,20 @@ def forward( fsdp_group, mu, rsigma, - weightmat if quantized_weight else None, + weightmat if fp8 and not is_weight_param_quantized else None, ln_out if weight.requires_grad else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") if cpu_offloading: + mark_not_offload( + weightmat, + weight, + bias, + ln_weight, + ln_bias, + ) ctx.grad_added_to_main_grad = hasattr(weight, "grad_added_to_main_grad") - if ctx.grad_added_to_main_grad: # If you are passing torch.nn.Parameter through the Torch hooks, you will # get back torch.Tensor. Torch rips off the Parameter wrapper. @@ -471,7 +495,7 @@ def forward( ctx.tensor_objects = tensor_objects ctx.requires_dgrad = inp_requires_grad ctx.requires_wgrad = weight.requires_grad - ctx.quantized_weight = quantized_weight + ctx.is_weight_param_quantized = is_weight_param_quantized if fuse_wgrad_accumulation and weight.requires_grad: # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates @@ -558,6 +582,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. 
ctx.tensor_objects = None @@ -578,7 +603,7 @@ def backward( ctx.fsdp_shapes, mu, rsigma, - weight if ctx.fp8 and ctx.quantized_weight else None, + weight if ctx.fp8 and not ctx.is_weight_param_quantized else None, ln_out, ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") @@ -588,8 +613,8 @@ def backward( if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: origin_weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - origin_weight.main_grad = main_grad + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + origin_weight.main_grad = main_grad # Configure Userbuffers communication (comm+GEMM overlap) ctx.ub_obj_gradout = None @@ -700,9 +725,9 @@ def backward( # -------------------------------------------------- # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorBase): + if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -825,14 +850,14 @@ def backward( ln_out_total_work.wait() ln_out_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): + if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) ln_out_total = ctx.input_quantizer(ln_out_total) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -868,7 +893,11 @@ def backward( main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": 
ctx.grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, "bias": (bias if (grad_bias is None and not ctx.fp8) else None), @@ -1024,7 +1053,7 @@ def wgrad_gemm( nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - # if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + # if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): # _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( @@ -1152,7 +1181,9 @@ class LayerNormLinear(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. 
return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the @@ -1474,6 +1505,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif other recipes (mxfp8, etc) def reset_layer_norm_parameters(self) -> None: @@ -1817,7 +1850,29 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_linear.""" + assert recipe.nvfp4(), "Incorrect recipe." 
+ if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # set input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4492abe3e..fb3327156 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -20,6 +20,7 @@ from transformer_engine.common.recipe import Recipe from transformer_engine.pytorch import torch_version +from transformer_engine.pytorch.tensor.utils import is_custom from .base import ( fill_userbuffers_buffer_for_all_gather, get_workspace, @@ -30,7 +31,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -43,6 +44,7 @@ init_method_constant, cast_if_needed, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, clear_tensor_data, requires_grad, needs_quantized_gemm, @@ -67,11 +69,17 @@ Float8Tensor, ) from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.nvfp4_tensor import NVFP4Quantizer from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ._common import apply_normalization, WeightGradStore -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ..tensor.quantized_tensor import ( - QuantizedTensorBase, +from ..cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_not_offload, + mark_activation_offload, +) +from ..quantized_tensor import ( + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, @@ -103,6 +111,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "sreglu": (tex.sreglu, tex.dsreglu, None), "silu": (tex.silu, tex.dsilu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), } if recipe.delayed() or recipe.mxfp8(): # Delayed scaling, fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] @@ -118,10 +127,17 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "sreglu": (tex.sreglu, tex.dsreglu, None), "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu), "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), } # no activation 
fusion written yet - # Per-tensor current scaling or fp8 blockwise scaling: [] - if recipe.float8_current_scaling() or recipe.float8_block_scaling(): + # Per-tensor current scaling or fp8 blockwise scaling or custom quantization: [] + # TODO(ksivaman): Fuse nvfp4 act once kernel is available. + if ( + recipe.float8_current_scaling() + or recipe.float8_block_scaling() + or recipe.nvfp4() + or recipe.custom() + ): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), @@ -133,6 +149,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): "sreglu": (tex.sreglu, tex.dsreglu, None), "silu": (tex.silu, tex.dsilu, None), "swiglu": (tex.swiglu, tex.dswiglu, None), + "clamped_swiglu": (tex.clamped_swiglu, tex.clamped_dswiglu, None), } raise NotImplementedError(f"Unhandled recipe type {recipe}") @@ -197,6 +214,7 @@ def forward( bwd_ln_sm_margin: int, zero_centered_gamma: bool, activation: str, + activation_params: Optional[dict], normalization: str, ub_overlap_ag: bool, ub_overlap_rs: bool, @@ -220,6 +238,7 @@ def forward( inputmat = inp.view((-1, in_features)) if fp8: assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight) + assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer) activation_func = _act_func( activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -230,6 +249,8 @@ def forward( ln_weight = cast_if_needed(ln_weight, activation_dtype) if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + if is_cpu_offload_enabled(): + start_offload(inputmat) tp_world_size = get_distributed_world_size(tp_group) backwards_needs_fc1_input = is_grad_enabled and fc1_weight.requires_grad @@ -267,11 +288,13 @@ def forward( # high precision layernorm output and output of the linear are returned # for debug: : layernorm output = High precision to enable processing of this norm + custom = is_custom(fc1_input_quantizer) with_quantized_norm = ( fp8 and not debug and not 
return_layernorm_output and not return_layernorm_output_gathered + and not custom ) # ROCm does not currently support quantized norm for Float8CurrentScalingQuantizer @@ -315,7 +338,8 @@ def forward( quantizer = None if fp8 or debug: quantizer = fc1_input_quantizer - if not with_quantized_norm: + # custom recipe doesn't need to support quantized AG + if not with_quantized_norm and not custom: ln_out = fc1_input_quantizer(ln_out) fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: @@ -347,8 +371,17 @@ def forward( # which handles weight caching etc. # FP8 cast to workspace buffer update_workspace = is_first_microbatch is None or is_first_microbatch - fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) - fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) + # No need to set the quantizer states if weights are already quantized + if isinstance(fc1_weight, QuantizedTensorStorage): + fc1_weight_quantizer = fc1_weight._quantizer + elif fc1_weight_quantizer is not None: + fc1_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) + + if isinstance(fc2_weight, QuantizedTensorStorage): + fc2_weight_quantizer = fc2_weight._quantizer + elif fc2_weight_quantizer is not None: + fc2_weight_quantizer.set_usage(rowwise=True, columnwise=is_grad_enabled and keep_fp8_weight_transpose_cache) + fc1_weight_final = module.get_weight_workspace( tensor=fc1_weight, quantizer=fc1_weight_quantizer, @@ -444,6 +477,7 @@ def forward( # ACTIVATION - sometimes activation is fused with the GEMM above. 
fc1_out_without_bias = None + act_params = activation_params or {} if bias_gelu_fusion: fc1_out = None @@ -453,19 +487,27 @@ def forward( act_out, _, fc1_out, _ = fc1_outputs elif debug: fc1_out, *_ = fc1_outputs - act_out = activation_func(fc1_out, None) + act_out = activation_func(fc1_out, None, **act_params) act_out = fc2_input_quantizer(act_out) else: fc1_out, *_ = fc1_outputs - if fp8 and FP8GlobalStateManager.get_fp8_recipe().float8_block_scaling(): - # tex.quantize does not support GELU fusion for blockwise. - act_out = activation_func(fc1_out, None) - act_out = tex.quantize(act_out, fc2_input_quantizer) + if fp8: + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_block_scaling(): + # tex.quantize does not support GELU fusion for blockwise + act_out = activation_func(fc1_out, None, **act_params) + act_out = tex.quantize(act_out, fc2_input_quantizer) + elif recipe.custom(): + # tex.quantize does not support custom quantizers + act_out = activation_func(fc1_out, None, **act_params) + act_out = fc2_input_quantizer(act_out) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) else: if fp8_calibration: - act_out = activation_func(fc1_out, None) + act_out = activation_func(fc1_out, None, **act_params) else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + act_out = activation_func(fc1_out, fc2_input_quantizer, **act_params) if not is_grad_enabled: clear_tensor_data(fc1_out) @@ -532,12 +574,11 @@ def forward( # Cache state for backward pass if is_grad_enabled: - # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. 
if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(fc1_weight_final, QuantizedTensorBase): + if isinstance(fc1_weight_final, QuantizedTensorStorage): fc1_weight_final.update_usage(columnwise_usage=True) - if isinstance(fc2_weight_final, QuantizedTensorBase): + if isinstance(fc2_weight_final, QuantizedTensorStorage): fc2_weight_final.update_usage(columnwise_usage=True) if cpu_offloading: @@ -569,6 +610,19 @@ def forward( if not fc2_weight.requires_grad: clear_tensor_data(act_out) act_out = None + + if cpu_offloading: + mark_not_offload( + ln_weight, + ln_bias, + fc1_weight_final, + fc1_weight, + fc1_bias, + fc2_weight_final, + fc2_weight, + fc2_bias, + ) + tensors_to_save, tensor_objects = prepare_for_saving( inputmat, ln_weight, @@ -623,6 +677,7 @@ def forward( ctx.device = device ctx.activation_dtype = activation_dtype ctx.activation = activation + ctx.activation_params = activation_params ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation @@ -698,6 +753,7 @@ def backward( mu, rsigma, ) = restore_from_saved(ctx.tensor_objects, saved_tensors) + # Delete the references to tensor objects once they've been consumed # by the `restore_from_saved` method to construct back the actual tensors. 
ctx.tensor_objects = None @@ -838,10 +894,10 @@ def backward( ) # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ctx.fc2_weight_quantizer is not None and isinstance( - fc2_weight, QuantizedTensorBase + fc2_weight, QuantizedTensorStorage ): fc2_weight.update_usage(columnwise_usage=True) @@ -926,14 +982,14 @@ def backward( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.fp8 or ctx.debug: - if isinstance(act_out, QuantizedTensorBase): + if isinstance(act_out, QuantizedTensorStorage): act_out.update_usage(columnwise_usage=True) else: ctx.fc2_input_quantizer.set_usage(rowwise=False, columnwise=True) act_out = ctx.fc2_input_quantizer(act_out) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -953,7 +1009,11 @@ def backward( else ctx.activation_dtype ), "quantization_params": ctx.fc2_grad_weight_quantizer, # wgrad in high precision - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(fc1_weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "bias": fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None, @@ -1003,6 +1063,7 @@ def fc2_wgrad_gemm( # -------------------------------------------------- # bias computation + act_params = ctx.activation_params or {} fc1_bias_grad = None fuse_gemm_and_bias_fc1_wgrad = False if ctx.fc1_grad_output_quantizer is not None: @@ -1016,7 +1077,7 @@ def fc2_wgrad_gemm( dact = ctx.fc1_grad_output_quantizer(dact) elif ctx.debug: dact_func = 
_act_func(ctx.activation)[1] - dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None) + dact = dact_func(fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params) fc1_bias_grad = dact.sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) elif ( @@ -1028,7 +1089,10 @@ def fc2_wgrad_gemm( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[2] fc1_bias_grad, dact = dbias_dact_quantize_func( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), ctx.fc1_grad_output_quantizer + fc2_dgrad, + fc1_out.to(ctx.activation_dtype), + ctx.fc1_grad_output_quantizer, + **act_params, ) # quantize bgrad gelu fused else: # Fusion: gemm + gelu, @@ -1037,12 +1101,15 @@ def fc2_wgrad_gemm( ctx.activation, ctx.fp8_recipe if ctx.fp8 else None )[1] dact = activation_func_bwd( - fc2_dgrad, fc1_out.to(ctx.activation_dtype), None + fc2_dgrad, fc1_out.to(ctx.activation_dtype), None, **act_params ) # activation in high precision if ctx.fp8: - # TODO float8 blockwise current scaling has no bgrad fusion for now - if isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer): + # TODO float8 blockwise current scaling (as well as custom quantizers) has no bgrad fusion for now + if ( + isinstance(ctx.fc1_grad_output_quantizer, Float8BlockQuantizer) + or ctx.fp8_recipe.custom() + ): fc1_bias_grad = dact.view(-1, dact.shape[-1]).sum(dim=0) dact = ctx.fc1_grad_output_quantizer(dact) else: @@ -1088,7 +1155,7 @@ def fc2_wgrad_gemm( # Make sure required data is available if ctx.fc1_weight_quantizer is not None and isinstance( - fc1_weight, QuantizedTensorBase + fc1_weight, QuantizedTensorStorage # this fixes a bug with upstream usage of fc1_weight_quantizer ): fc1_weight.update_usage(columnwise_usage=True) @@ -1162,7 +1229,7 @@ def fc2_wgrad_gemm( ln_out_total_work.wait() ln_out_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(ln_out_total, QuantizedTensorBase): + if isinstance(ln_out_total, QuantizedTensorStorage): ln_out_total.update_usage(columnwise_usage=True) else: 
ctx.fc1_input_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1172,7 +1239,7 @@ def fc2_wgrad_gemm( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.fp8 or ctx.debug: - if isinstance(dact, QuantizedTensorBase): + if isinstance(dact, QuantizedTensorStorage): dact.update_usage(columnwise_usage=True) else: ctx.fc1_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1195,7 +1262,11 @@ def fc2_wgrad_gemm( else ctx.activation_dtype ), "quantization_params": ctx.fc1_grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(fc2_weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None, "bias": fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None, @@ -1405,6 +1476,7 @@ def fc1_wgrad_gemm( None, # bwd_ln_sm_margin None, # zero_centered_gamma None, # activation + None, # activation_params None, # normalization None, # ub_overlap_ag None, # ub_overlap_rs @@ -1442,7 +1514,11 @@ class LayerNormMLP(TransformerEngineBaseModule): activation : str, default = 'gelu' activation function used. Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', and 'swiglu'. + 'silu', 'swiglu', and 'clamped_swiglu'. + activation_params : dict, default = `None` + Additional parameters for the activation function. + At the moment, only used for 'clamped_swiglu' activation which + supports 'limit' and 'alpha' parameters. init_method : Callable, default = `None` used for initializing FC1 weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. @@ -1498,7 +1574,9 @@ class LayerNormMLP(TransformerEngineBaseModule): the weight gradient. 
When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias for FC2, but instead return the bias value during the forward pass together with the @@ -1555,6 +1633,7 @@ def __init__( bias: bool = True, normalization: str = "LayerNorm", activation: str = "gelu", + activation_params: Optional[dict] = None, output_layer_init_method: Optional[Callable] = None, fuse_wgrad_accumulation: bool = False, params_dtype: Optional[torch.dtype] = None, @@ -1584,6 +1663,7 @@ def __init__( assert normalization in ["LayerNorm", "RMSNorm"], "Unsupported normalization type!" self.use_bias = bias self.activation = activation + self.activation_params = activation_params self.return_bias = return_bias self.apply_bias = bias and not return_bias self.return_layernorm_output = return_layernorm_output @@ -1664,7 +1744,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]: + if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1747,6 +1827,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) 
def reset_layer_norm_parameters(self) -> None: @@ -1917,6 +1999,7 @@ def forward( self.bwd_ln_sm_margin, self.zero_centered_gamma, self.activation, + self.activation_params, self.normalization, self.ub_overlap_ag, self.ub_overlap_rs, @@ -1968,7 +2051,10 @@ def _get_quantizers(self, fp8_output): fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] fc2_input_quantizer.set_usage( rowwise=True, - columnwise=isinstance(fc2_input_quantizer, (MXFP8Quantizer, Float8BlockQuantizer)), + columnwise=isinstance( + fc2_input_quantizer, + (MXFP8Quantizer, Float8BlockQuantizer, NVFP4Quantizer), + ), ) fc1_input_quantizer.internal = True if fp8_output: @@ -2045,6 +2131,19 @@ def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Ten fc1_out = onnx_gemm(fc1_weight, ln_out, fc1_bias) fc1_out = fc1_out.to(torch.float32) # activation is computed in fp32 + act_params = self.activation_params or {} + # Default params for clamped_swiglu in Transformer Engine + clamped_swiglu_limit, clamped_swiglu_alpha = act_params.get("limit", 7.0), act_params.get( + "alpha", 1.702 + ) + + def _clamped_swiglu(x, limit, alpha): + x_glu, x_linear = x.chunk(2, dim=-1) + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + y = out_glu * (x_linear + 1) + return y activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), @@ -2059,6 +2158,9 @@ def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Ten * x.chunk(2, -1)[1], "silu": torch.nn.functional.silu, "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "clamped_swiglu": lambda x: _clamped_swiglu( + x, clamped_swiglu_limit, clamped_swiglu_alpha + ), } if self.activation not in activation_map: raise ValueError(f"Unsupported activation in onnx export: {self.activation}") @@ -2173,7 +2275,29 @@ def 
_customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT2 ].amax_reduction_group = self.tp_group - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + layernorm_mlp.""" + assert recipe.nvfp4(), "Incorrect recipe." + if fwd: + if self.sequence_parallel and self.set_parallel_mode: + # fc1_input_quantizer: customize input_quantizer with amax reduction TP group, column parallel + sequence parallel here + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.set_parallel_mode: + # fc2_grad_output_quantizer: customize grad_output_quantizer with amax reduction TP group, row parallel + sequence parallel here + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT2 + ].amax_reduction_group = self.tp_group + + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" return [self.fc1_weight, self.fc2_weight] @@ -2212,7 +2336,7 @@ def backward_dw(self): Execute the delayed weight gradient computation. This method is called after the main backward pass to compute weight gradients. 
""" - if self.wgrad_store is None or not self.wgrad_store.delay_wgrad_compute(): + if not self.need_backward_dw(): return with torch.cuda.nvtx.range("_LayerNormMLP_wgrad"): (fc2_wgrad, fc2_bias_grad_, *_), tensor_list_fc2 = self.wgrad_store.pop() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 88ed6356b..0d43776f1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -28,7 +28,7 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..utils import ( cast_if_needed, clear_tensor_data, @@ -37,6 +37,7 @@ requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, + assert_dim_for_all_gather, nvtx_range_pop, nvtx_range_push, ) @@ -58,17 +59,23 @@ from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo from ..graph import is_graph_capturing -from ..tensor.quantized_tensor import ( +from ..quantized_tensor import ( QuantizedTensor, - QuantizedTensorBase, + QuantizedTensorStorage, Quantizer, prepare_for_saving, restore_from_saved, ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer +from ..tensor.utils import is_custom from ..export import is_in_onnx_export_mode, assert_warmed_up -from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ..cpu_offload import ( + is_cpu_offload_enabled, + start_offload, + mark_not_offload, + mark_activation_offload, +) from ...debug.pytorch.debug_state import TEDebugState from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -156,6 +163,9 @@ def forward( ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG + # custom recipe check + custom = is_custom(input_quantizer) or is_custom(weight_quantizer) + # ------------------------------------------------------ 
# Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication @@ -166,6 +176,7 @@ def forward( own_quantized_input = False if fp8: assert_dim_for_fp8_exec(inputmat, weight) + assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer) if save_original_input: assert not isinstance( input_quantizer, Float8Quantizer @@ -177,7 +188,7 @@ def forward( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorBase): + if not isinstance(inputmat, QuantizedTensorStorage) and not custom: own_quantized_input = True input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( @@ -215,7 +226,7 @@ def forward( else: # Do not all-gather input tensor if fp8 or debug: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat.update_usage(rowwise_usage=True) else: if input_quantizer is None: @@ -228,6 +239,9 @@ def forward( else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP inputmat_total = inputmat + + if is_cpu_offload_enabled(): + start_offload(inputmat) nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # ------------------------------------------------------ # Input tensor is ready for GEMM... 
@@ -239,7 +253,8 @@ def forward( weightmat = weight if fp8 or debug: # Configure quantizer - if weight_quantizer is not None: + # No need to set the quantizer states if weight is already quantized + if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): columnwise_usage = is_grad_enabled and inp.requires_grad and keep_fp8_weight_transpose_cache if not columnwise_usage and keep_fp8_weight_transpose_cache: columnwise_usage = ( @@ -247,7 +262,9 @@ def forward( and not in_fp8_activation_recompute_phase() ) weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) - + elif isinstance(weight, QuantizedTensor): + # If weight is already quantized, no need to set quantizer states + weight_quantizer = weight._quantizer # Get quantized weight update_workspace = is_first_microbatch is None or is_first_microbatch weightmat = module.get_weight_workspace( @@ -374,7 +391,7 @@ def forward( if ( backward_needs_input and own_quantized_input - and isinstance(inputmat, QuantizedTensorBase) + and isinstance(inputmat, QuantizedTensorStorage) ): if ( ctx.backward_input_needs_gather @@ -393,7 +410,7 @@ def forward( # Weight with column-wise usage is needed for dgrad GEMM while keeping fp8 weight transpose cache. 
if inp.requires_grad and keep_fp8_weight_transpose_cache and not use_fsdp2: - if isinstance(weightmat, QuantizedTensorBase): + if isinstance(weightmat, QuantizedTensorStorage): weightmat.update_usage(columnwise_usage=True) if cpu_offloading and saved_inputmat is not None: @@ -406,7 +423,7 @@ def forward( ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, - weightmat if fp8 and not isinstance(weight, QuantizedTensorBase) else None, + weightmat if fp8 and not isinstance(weight, QuantizedTensorStorage) else None, ) nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") @@ -421,6 +438,7 @@ def forward( # weights if weights are externally touched outside this module ctx.weight_object = weight + mark_not_offload(weight, weightmat, bias) # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( saved_inputmat, @@ -450,6 +468,7 @@ def forward( ctx.main_grad_func = lambda: weight.main_grad ctx.debug = debug + ctx.custom = custom ctx.cpu_offloading = cpu_offloading ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = bias is not None @@ -515,8 +534,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.cpu_offloading: if ctx.grad_added_to_main_grad: weight = ctx.weight_object - if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: - weight.main_grad = main_grad + if ctx.requires_wgrad and ctx.fuse_wgrad_accumulation: + weight.main_grad = main_grad # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already @@ -618,10 +637,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work = None if ctx.requires_wgrad: if ctx.fp8 or ctx.debug: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass - elif ctx.debug: + elif ctx.debug or ctx.custom: # Debug quantizer will be applied immediately 
before wgrad GEMM pass else: @@ -637,7 +656,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: - if isinstance(inputmat, QuantizedTensorBase): + if isinstance(inputmat, QuantizedTensorStorage): inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) else: inputmat = cast_if_needed(inputmat, ctx.activation_dtype) @@ -682,9 +701,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.requires_dgrad: # Make sure required data is available - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) - if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorBase): + if ctx.weight_quantizer is not None and isinstance( + weight_fp8, QuantizedTensorStorage + ): weight_fp8.update_usage(columnwise_usage=True) # Choose whether to use GEMM kernel with split accumulator @@ -710,6 +731,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # dgrad GEMM # Note: dx = dy * w + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") gemm_out, *_, reduce_scatter_out = general_gemm( weight_fp8, @@ -770,7 +792,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total_work.wait() inputmat_total_work = None if ctx.fp8 or ctx.debug: - if isinstance(inputmat_total, QuantizedTensorBase): + if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: ctx.input_quantizer.set_usage(rowwise=False, columnwise=True) @@ -812,7 +834,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.fp8 or ctx.debug: - if isinstance(grad_output, QuantizedTensorBase): + if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: 
ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -848,7 +870,11 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype ), "quantization_params": ctx.grad_weight_quantizer, - "accumulate": accumulate_wgrad_into_param_main_grad, + "accumulate": ( + accumulate_wgrad_into_param_main_grad + if not getattr(weight, "overwrite_main_grad", False) + else False + ), "layout": "NT", "out": main_grad if ctx.fuse_wgrad_accumulation else None, "bias": (bias if (grad_bias is None and not ctx.fp8) else None), @@ -967,7 +993,7 @@ def wgrad_gemm( nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") # Scatter fp8 weight buffers - if ctx.fp8 and not isinstance(weight, QuantizedTensorBase): + if ctx.fp8 and not isinstance(weight, QuantizedTensorStorage): _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) return ( wgrad, @@ -1070,7 +1096,9 @@ class Linear(TransformerEngineBaseModule): the weight gradient. When enabled, it is assumed that the weights have an additional `main_grad` attribute (used instead of the regular `grad`) which is a pre-allocated buffer of the correct - size to accumulate gradients in. + size to accumulate gradients in. This argument along with + weight tensor having attribute 'overwrite_main_grad' set to True + will overwrite `main_grad` instead of accumulating. return_bias : bool, default = `False` when set to `True`, this module will not apply the additive bias itself, but instead return the bias value during the forward pass together with the @@ -1363,6 +1391,8 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: self._customize_quantizers_float8_current_scaling(fwd, recipe) elif recipe.float8_block_scaling(): self._customize_quantizers_float8_blockwise_scaling(fwd, recipe) + elif recipe.nvfp4(): + self._customize_quantizers_nvfp4(fwd, recipe) # elif for other recipes (mxfp8, etc.) 
def reset_parameters(self, defer_init=False): @@ -1452,7 +1482,6 @@ def forward( if not debug else self._get_debug_quantizers(fp8_output, fp8_grad) ) - if debug: if self.no_debug_features_active(quantizers): debug = False @@ -1558,7 +1587,7 @@ def _get_debug_quantizers(self, fp8_output, fp8_grad): for name, q in zip(names, original_quantizers) ) - def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: + def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" unfused_weights = [getattr(self, name) for name in self.weight_names] if any(isinstance(w, QuantizedTensor) for w in unfused_weights): @@ -1694,6 +1723,28 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe tex.FP8BwdTensors.GRAD_OUTPUT1 ].amax_reduction_group = self.tp_group + def _customize_quantizers_nvfp4(self, fwd: bool, recipe: Recipe) -> None: + """Customize quantizers based on current scaling recipe + linear.""" + assert recipe.nvfp4(), "Incorrect recipe." 
+ if fwd: + if self.sequence_parallel and self.parallel_mode == "column": + # customize input_quantizer with amax reduction TP group + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].with_amax_reduction = True + self.quantizers["scaling_fwd"][ + tex.FP8FwdTensors.GEMM1_INPUT + ].amax_reduction_group = self.tp_group + else: + if self.sequence_parallel and self.parallel_mode == "row": + # customize grad_output_quantizer with amax reduction TP group + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].with_amax_reduction = True + self.quantizers["scaling_bwd"][ + tex.FP8BwdTensors.GRAD_OUTPUT1 + ].amax_reduction_group = self.tp_group + def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" if not self.fp8 and not self.fp8_calibration: diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 99bbc34c4..a07ffea43 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -11,19 +11,19 @@ from transformer_engine_torch import FP8TensorMeta from .. 
import torch_version -from ..fp8 import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager from ..tensor.float8_tensor import Float8Tensor -from ..tensor.quantized_tensor import QuantizedTensorBase +from ..quantized_tensor import QuantizedTensorStorage from ..utils import canonicalize_dtype -def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorBase) -> bool: +def is_quantized_tensor(tensor: torch.Tensor | QuantizedTensorStorage) -> bool: """Check if tensor is a quantized tensor""" - return isinstance(tensor, QuantizedTensorBase) + return isinstance(tensor, QuantizedTensorStorage) def maybe_dequantize( - tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None + tensor: torch.Tensor | QuantizedTensorStorage, dtype: torch.dtype | None = None ) -> torch.Tensor: """Dequantize tensor to given dtype or just convert if not a quantized tensor""" if is_quantized_tensor(tensor): diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 2c903675f..28d49bf7b 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,7 +4,19 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU +from .activation import ( + GELU, + GEGLU, + QGELU, + QGEGLU, + ReLU, + ReGLU, + SReLU, + SReGLU, + SiLU, + SwiGLU, + ClampedSwiGLU, +) from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 22779b601..8a754c638 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -28,6 +28,7 @@ "SReGLU", "SiLU", "SwiGLU", + "ClampedSwiGLU", ] @@ -392,3 +393,38 @@ def 
_activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dswiglu(*args, **kwargs) + + +class ClampedSwiGLU(_ActivationOperation): + r"""GPT-OSS + Implementation based on `GPT-OSS`__. + + This activation has two differences compared to the original SwiGLU + 1. Both gate and pre-activations are clipped based on parameter limit. + 2. Activation uses sigmoid(alpha * x) instead of sigmoid(x) used in Swish activation. + + .. warning:: The input tensor is chunked along the last dimension to get gates/pre-activations which is differnt + from GPT OSS implementation where the gates/pre-activations are assumed to be interleaved in the input tensor. + + Parameters + ---------- + limit: float + The clamp limit. + alpha: float + The scaling factor for the sigmoid function used in the activation. + cache_quantized_input: bool, default = False + Quantize input tensor when caching for use in the backward pass. + """ + + def __init__( + self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False + ): + super().__init__(cache_quantized_input=cache_quantized_input) + self.limit = limit + self.alpha = alpha + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 70c70c54d..432d8c134 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,7 +19,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...fp8 import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, Recipe from 
...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -29,7 +29,7 @@ ) from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -80,7 +80,9 @@ class BasicLinear(BasicOperation): autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be meaningful. This is primarily intented to integrate with - Megatron-LM. + Megatron-LM. This argument along with weight tensor having + attribute 'overwrite_main_grad' set to True will overwrite + `main_grad` instead of accumulating. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -301,8 +303,8 @@ def reset_parameters(self) -> None: "Tried to quantize weight with deferred initialization " "due to meta device, but no quantizer was available. " "This is most likely because the weight was initialized " - "within fp8_model_init, but the forward pass was not " - "performed within fp8_autocast." + "within quantized_model_init, but the forward pass was not " + "performed within autocast." ) quantizer.set_usage( rowwise=True, @@ -322,6 +324,20 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + super().pre_fuser_forward(requires_grad=requires_grad) + if FP8GlobalStateManager.is_fp8_enabled(): + # Configure quantizer usages + # Note: We cache the quantized input for backward pass, + # but discard the quantized weights. 
+ weight_requires_grad = requires_grad and self.weight.requires_grad + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + grad_output_quantizer = self.get_quantizer("backward", 0) + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + weight_quantizer.set_usage(rowwise=True, columnwise=False) + grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) @@ -352,6 +368,35 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: and not getattr(self, "_with_quantized_weight", False) ) + # Recipe-specific configuration + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + if recipe is not None: + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + if recipe.nvfp4(): + if getattr(self, "sequence_parallel", False): + tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None) + if 
tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + elif tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + @staticmethod def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin @@ -568,7 +613,7 @@ def _functional_forward( # Prepare input tensor for backward pass if weight_requires_grad: if with_quantized_compute and is_quantized_tensor(x_local): - if not (isinstance(x_local, Float8TensorBase) and with_x_all_gather): + if not (isinstance(x_local, Float8TensorStorage) and with_x_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) else: @@ -731,7 +776,7 @@ def _functional_backward( if with_quantized_compute: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - input_quantizer.set_usage(columnwise=True) + input_quantizer.set_usage(rowwise=False, columnwise=True) if with_x_all_gather: x, x_async = gather_along_first_dim( x_local, @@ -912,34 +957,13 @@ def op_forward( input_requires_grad = ctx.requires_grad weight_requires_grad = ctx.requires_grad and self.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = self.get_quantizer("forward", 0) weight_quantizer = self.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer grad_output_quantizer = self.get_quantizer("backward", 0) grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - # Configure quantizers - # Note: We cache the quantized input for backward pass, - # but discard the quantized weights. 
- input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) - weight_quantizer.set_usage(rowwise=True, columnwise=False) - - recipe = FP8GlobalStateManager.get_fp8_recipe() - if recipe.float8_current_scaling(): - input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -997,6 +1021,7 @@ def op_backward( weight_param = self.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 30ccf5ebc..38b2a59a7 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -11,7 +11,7 @@ import transformer_engine_torch as tex from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from 
...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_autocast_dtype, maybe_dequantize from ..op import BasicOperation, OperationContext @@ -56,7 +56,7 @@ def op_forward( out = input_ elif impl == "fused": x = input_ - if not isinstance(x, Float8TensorBase): + if not isinstance(x, Float8TensorStorage): x = maybe_dequantize(x, dtype=dtype) out, mask = tex.dropout_fwd(x, self.dropout_probability) elif impl == "unfused": diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index dcfc3c4f7..87c65d4b2 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -9,7 +9,7 @@ import torch -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from .._common import is_quantized_tensor from ..op import BasicOperation, OperationContext from ...tensor import Quantizer @@ -18,8 +18,8 @@ class Quantize(BasicOperation): """Quantize tensor data - Uses FP8 recipe from `fp8_autocast` context. When called outside - of an `fp8_autocast` context, this is an identity operation. + Uses recipe from `autocast` context. When called outside + of an `autocast` context, this is an identity operation. 
Parameters ---------- diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index 40510c856..7897ef164 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -10,7 +10,7 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.fp8 import Recipe +from transformer_engine.pytorch.quantization import Recipe from transformer_engine.pytorch.ops.basic import Bias from transformer_engine.pytorch.ops.basic.activation import ( _ActivationOperation, diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 845ba262a..a86745a68 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -59,6 +59,7 @@ def fuser_backward( weight_param = linear_op.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index a9595d516..832e51de8 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -60,6 +60,7 @@ def fuser_backward( weight_param = linear_op.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " diff --git 
a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 02bcfee0a..74bd3d1b3 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -11,7 +11,7 @@ import torch from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import BasicLinear, Bias from ..op import FusedOperation, FusibleOperation, OperationContext @@ -85,7 +85,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = next_op_input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 15cc081c1..6d5d55339 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -11,7 +11,7 @@ import torch from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, Bias from ..op import FusedOperation, FusibleOperation, OperationContext @@ -79,7 +79,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = 
linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 21190d4fc..24788bcdf 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -11,7 +11,7 @@ import torch from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..op import ( @@ -58,7 +58,7 @@ def fuser_forward( input_requires_grad = linear_op_ctx.requires_grad weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad - # FP8 metadata + # Quantizers input_quantizer = linear_op.get_quantizer("forward", 0) weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 1ecdba625..fd1820d15 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -21,7 +21,7 @@ get_ub, get_workspace, ) -from ...tensor.quantized_tensor import Quantizer +from ...quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ..basic import BasicLinear, Bias, ReduceScatter @@ -523,6 +523,7 @@ def fuser_backward( weight_param = linear_op.weight if hasattr(weight_param, "__fsdp_param__"): weight_param.main_grad = weight_param.get_main_grad() + accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) if not hasattr(weight_param, "main_grad"): raise 
RuntimeError( "BasicLinear op is configured with " diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index a604e57dc..057eb576d 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -14,16 +14,16 @@ from ...cpp_extensions import general_gemm from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size -from ...fp8 import FP8GlobalStateManager +from ...quantization import FP8GlobalStateManager from ...module.base import ( fill_userbuffers_buffer_for_all_gather, get_ub, get_workspace, _2X_ACC_FPROP, ) -from ...tensor.quantized_tensor import Quantizer +from ...quantized_tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer -from ...tensor._internal.float8_tensor_base import Float8TensorBase +from ...tensor.storage.float8_tensor_storage import Float8TensorStorage from .._common import maybe_dequantize, is_quantized_tensor from ..basic import BasicLinear, Bias, ReduceScatter from ..op import ( @@ -267,7 +267,7 @@ def _functional_forward( # Prepare input tensor for backward pass if weight_requires_grad: if with_quantized_compute and is_quantized_tensor(x_local): - if not (isinstance(x_local, Float8TensorBase) and with_ub_all_gather): + if not (isinstance(x_local, Float8TensorStorage) and with_ub_all_gather): # FP8 does not support all-gather of transpose data x_local.update_usage(rowwise_usage=False, columnwise_usage=True) else: diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index df8843649..7eb04fa27 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -13,7 +13,7 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, 
DelayedScaling +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, Recipe, DelayedScaling from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, @@ -34,7 +34,7 @@ fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) -from transformer_engine.pytorch.tensor.quantized_tensor import ( +from transformer_engine.pytorch.quantized_tensor import ( prepare_for_saving, restore_from_saved, ) @@ -480,6 +480,10 @@ def __call__( # Attempt to fuse operations if neccesary self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) + # Initialization before forward + for idx, op in enumerate(self._basic_ops): + op.pre_fuser_forward(requires_grad=idx >= self.first_op_requiring_backward) + # Fuser forward pass if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 903bc49d5..639817ada 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -14,10 +14,10 @@ import torch from transformer_engine.common.recipe import Recipe -from ..fp8 import ( +from ..quantization import ( FP8GlobalStateManager, RecipeState, - fp8_autocast, + autocast, ) from ..tensor import Quantizer @@ -65,6 +65,13 @@ def is_fused_op(self) -> bool: def pre_first_fuser_forward(self) -> None: """Preprocessing before first fuser forward pass""" + def pre_fuser_forward( + self, + *, + requires_grad: bool, # pylint: disable=unused-argument + ) -> None: + """Preprocessing before fuser forward pass""" + def get_input_quantizer(self) -> Optional[Quantizer]: """Get builder class for quantized input tensor""" @@ -588,6 +595,9 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: extra[key] = val state[mode]["extra_fp8_variables"] = extra + if not state: + return torch.empty(0, dtype=torch.uint8) + # Serialize state into byte tensor torch.cuda.synchronize() state_serialized = 
bytearray(pickle.dumps(state)) @@ -624,7 +634,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: # Get op's quantizer state, initializing if needed if self._fp8_metas is None or self._fp8_metas[mode] is None: - with fp8_autocast(fp8_recipe=state[mode]["recipe"]): + with autocast(recipe=state[mode]["recipe"]): self.reset_recipe_state(recipe=state[mode]["recipe"]) fp8_meta = self._fp8_metas[mode] @@ -710,6 +720,10 @@ def pre_first_fuser_forward(self) -> None: for op in self.basic_ops: op.pre_first_fuser_forward() + def pre_fuser_forward(self, *, requires_grad: bool) -> None: + for op in self.basic_ops: + op.pre_fuser_forward(requires_grad=requires_grad) + def forward( self, input: torch.Tensor, # pylint: disable=redefined-builtin diff --git a/transformer_engine/pytorch/optimizers/fused_adam.py b/transformer_engine/pytorch/optimizers/fused_adam.py index 52864904c..a40fe9302 100644 --- a/transformer_engine/pytorch/optimizers/fused_adam.py +++ b/transformer_engine/pytorch/optimizers/fused_adam.py @@ -376,9 +376,9 @@ def _initialize_state( """ dtype = self.name_to_dtype_map[state_name] if store_param_remainders: - data = torch.zeros_like(param, dtype=torch.int16) + data = torch.zeros(param.shape, dtype=torch.int16, device=param.device) else: - data = torch.empty_like(param, dtype=dtype) + data = torch.empty(param.shape, dtype=dtype, device=param.device) if zero_buffer: data.zero_() diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index ea3e67a57..f73bc9a96 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -10,7 +10,7 @@ import transformer_engine_torch as tex import transformer_engine.pytorch.triton.permutation as triton_permutation from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.tensor.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.quantized_tensor import QuantizedTensor from 
transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py new file mode 100644 index 000000000..d8dff33d5 --- /dev/null +++ b/transformer_engine/pytorch/quantization.py @@ -0,0 +1,1425 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Quantization utilities for TransformerEngine""" +from __future__ import annotations + +import abc +import itertools +import functools +import warnings +import os +from contextlib import contextmanager +from collections import deque +from typing import Callable, List, Optional, Dict, Any, Tuple, Union + +import torch +import transformer_engine_torch as tex +from transformer_engine.common.recipe import ( + Recipe, + DelayedScaling, + Format, + MXFP8BlockScaling, + Float8CurrentScaling, + Float8BlockScaling, + NVFP4BlockScaling, + CustomRecipe, +) + +from .constants import dist_group_type +from .utils import get_device_compute_capability +from .jit import jit_fuser + +from torch.utils.cpp_extension import IS_HIP_EXTENSION +from .utils import get_torch_float8_e4m3_type, get_torch_float8_e5m2_type + +__all__ = [ + "autocast", + "quantized_model_init", + "is_fp8_available", + "is_mxfp8_available", + "is_fp8_block_scaling_available", + "is_nvfp4_available", + "get_default_recipe", +] + + +@functools.lru_cache(maxsize=None) +def check_fp8_support() -> Tuple[bool, str]: + """Return if fp8 support is available""" + if IS_HIP_EXTENSION: + gpu_arch = get_device_compute_capability() + if gpu_arch in ((9, 4), (9, 5)): + return True, "" + else: + return False, "Device 
arch gfx94x or gfx95x required for FP8 execution." + if get_device_compute_capability() >= (9, 0): # hopper and above + return True, "" + if get_device_compute_capability() < (8, 9): # pre-ada + return False, "Device compute capability 8.9 or higher required for FP8 execution." + if tex.get_cublasLt_version() < 120103: + return False, "CublasLt version 12.1.3.x or higher required for FP8 execution on Ada." + if float(torch.version.cuda) < 12.1: + return False, "Cuda version 12.1 or higher required for FP8 execution on Ada." + return True, "" + + +@functools.lru_cache(maxsize=None) +def check_mxfp8_support() -> Tuple[bool, str]: + """Return if MXFP8 support is available""" + if IS_HIP_EXTENSION: + if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") == "0": + return False, "MXFP8 support is not enabled." + gpu_arch = get_device_compute_capability() + if gpu_arch == (9, 5): + return True, "" + return False, "Gfx95x is required for MXFP8 execution." + if get_device_compute_capability() >= (12, 0): + return False, "MXFP8 (for all gemm layouts) is not supported on 12.0+ architectures yet." + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for MXFP8 execution." + + +@functools.lru_cache(maxsize=None) +def check_nvfp4_support() -> Tuple[bool, str]: + if IS_HIP_EXTENSION: + return False, "ROCm TE currently not supporting NVFP4" + """Return if nvfp4 support is available""" + if get_device_compute_capability() >= (10, 0): # blackwell and above + return True, "" + return False, "Device compute capability 10.0 or higher required for NVFP4 execution."
+ + +@functools.lru_cache(maxsize=None) +def check_fp8_block_scaling_support() -> Tuple[bool, str]: + """Return if fp8 block scaling support is available""" + if IS_HIP_EXTENSION: + return False, "FP8 block scaled gemm not yet supported for ROCm" + if get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.9: + return True, "" + return ( + False, + "FP8 block scaled GEMM requires compute capability 9.0 or higher and CUDA >= 12.9.", + ) + + +def check_recipe_support(recipe: Recipe) -> None: + """Check if the given recipe is supported.""" + recipe_supported = True + unsupported_reason = "" + if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): + recipe_supported, unsupported_reason = check_fp8_support() + elif isinstance(recipe, Float8BlockScaling): + recipe_supported, unsupported_reason = check_fp8_block_scaling_support() + elif isinstance(recipe, MXFP8BlockScaling): + recipe_supported, unsupported_reason = check_mxfp8_support() + assert recipe_supported, unsupported_reason + + +def get_default_fp8_recipe() -> Recipe: + """FP8 recipe with default args.""" + if IS_HIP_EXTENSION: + if os.getenv("NVTE_ROCM_ENABLE_MXFP8", "0") != "2": + return DelayedScaling() + gpu_arch = get_device_compute_capability() + if gpu_arch == (9, 5): + return MXFP8BlockScaling() + return DelayedScaling() + if check_mxfp8_support()[0]: + return MXFP8BlockScaling() + if get_device_compute_capability() >= (12, 0): + # This is a temporary restriction until MXFP8 is supported for all gemm layouts. 
+ return Float8CurrentScaling() + return DelayedScaling() + + +def get_default_recipe() -> Recipe: + """Returns the default training recipe based on available device.""" + return get_default_fp8_recipe() + + +def get_fp8_torch_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> torch.dtype: + """Get fp8 data type according to recipe and tensor""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return get_torch_float8_e4m3_type() + return get_torch_float8_e5m2_type() + + +def get_fp8_te_dtype(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: + """Get fp8 data type according to recipe and tensor""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return tex.DType.kFloat8E4M3 + return tex.DType.kFloat8E5M2 + + +def get_fp4_te_dtype(fp4_recipe: Recipe) -> tex.DType: + """Get fp4 data type according to recipe and tensor""" + if fp4_recipe.fp4_format == Format.E2M1: + return tex.DType.kFloat4E2M1 + raise ValueError(f"Unsupported FP4 format: {fp4_recipe.fp4_format}") + + +def get_fp8_max(fp8_recipe: Recipe, fprop_tensor: bool = True) -> tex.DType: + """Get max representible FP8 value.""" + if fp8_recipe.fp8_format == Format.E4M3 or ( + fp8_recipe.fp8_format == Format.HYBRID and fprop_tensor + ): + return Format.E4M3.value.max_fwd + return Format.E5M2.value.max_fwd + + +def is_fp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if FP8 support is available for the delayed + scaling and per tensor current scaling recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for FP8 is available. 
+ + """ + if return_reason: + return check_fp8_support() + return check_fp8_support()[0] + + +def is_mxfp8_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for the MXFP8 recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for MXFP8 is available. + + """ + if return_reason: + return check_mxfp8_support() + return check_mxfp8_support()[0] + + +def is_fp8_block_scaling_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for the FP8 block scaling recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for FP8 block scaling is available. + + """ + if return_reason: + return check_fp8_block_scaling_support() + return check_fp8_block_scaling_support()[0] + + +def is_nvfp4_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine if support is available for the NVFP4 recipe. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when required support is not available. The reason + will be an empty string if support for NVFP4 is available. 
+ + """ + if return_reason: + return check_nvfp4_support() + return check_nvfp4_support()[0] + + +class FP8GlobalStateManager: + """Class to keep track of and manipulate the global + FP8 state at different stages of execution. + """ + + FP8_ENABLED = False + FP8_CALIBRATION = False + FP8_RECIPE = None + FP8_DISTRIBUTED_GROUP = None + FP8_PARAMETERS = False + HIGH_PRECISION_INIT_VAL = False + IS_FIRST_FP8_MODULE = False + FP8_GRAPH_CAPTURING = False + SKIP_FP8_REDUCTION_FOR_FSDP2 = False + AUTOCAST_DEPTH = 0 + global_amax_buffer = {} + global_amax_history_buffer = {} + global_scale_buffer = {} + fp8_tensors_recompute_buffer = [] + fp8_available = None + reason_for_no_fp8 = "" + autocast_arguments = {} + skip_fp8_weight_update_tensor = None + mxfp8_available = None + reason_for_no_mxfp8 = "" + fp8_block_scaling_available = None + reason_for_no_fp8_block_scaling = None + nvfp4_available = None + reason_for_no_nvfp4 = "" + + @classmethod + def reset(cls) -> None: + """Reset the global state""" + cls.FP8_ENABLED = False + cls.FP8_CALIBRATION = False + cls.FP8_RECIPE = None + cls.FP8_DISTRIBUTED_GROUP = None + cls.FP8_PARAMETERS = False + cls.HIGH_PRECISION_INIT_VAL = False + cls.IS_FIRST_FP8_MODULE = False + cls.FP8_GRAPH_CAPTURING = False + cls.AUTOCAST_DEPTH = 0 + cls.global_amax_buffer = {} + cls.global_amax_history_buffer = {} + cls.global_scale_buffer = {} + cls.fp8_tensors_recompute_buffer = [] + cls.fp8_available = None + cls.reason_for_no_fp8 = "" + cls.autocast_arguments = {} + cls.skip_fp8_weight_update_tensor = None + cls.mxfp8_available = None + cls.reason_for_no_mxfp8 = "" + cls.fp8_block_scaling_available = None + cls.reason_for_no_fp8_block_scaling = "" + + @classmethod + def set_skip_fp8_weight_update_tensor(cls, skip: bool) -> None: + """`skip_fp8_weight_update_tensor` inplace setter.""" + if cls.skip_fp8_weight_update_tensor is None: + cls.skip_fp8_weight_update_tensor = torch.empty(1, dtype=torch.float32, device="cuda") + 
cls.skip_fp8_weight_update_tensor.fill_(skip) + + @classmethod + def get_skip_fp8_weight_update_tensor(cls) -> None: + """`skip_fp8_weight_update_tensor` getter.""" + return cls.skip_fp8_weight_update_tensor + + @classmethod + def is_fp8_available(cls) -> Tuple[bool, str]: + """Return if fp8 support is available""" + return check_fp8_support() + + @classmethod + def is_mxfp8_available(cls) -> Tuple[bool, str]: + """Return if MXFP8/current scaling support is available.""" + return check_mxfp8_support() + + @classmethod + def is_fp8_block_scaling_available(cls) -> Tuple[bool, str]: + """Return if Float8 block scaling support is available.""" + return check_fp8_block_scaling_support() + + @classmethod + def is_nvfp4_available(cls) -> Tuple[bool, str]: + """Return if NVFP4 support is available.""" + return check_nvfp4_support() + + @staticmethod + def get_meta_tensor_key(forward: bool = True) -> str: + """Returns scaling key in `fp8_meta`.""" + if forward: + return "scaling_fwd" + return "scaling_bwd" + + @staticmethod + def get_fwd_bwd_key(forward: bool = True) -> str: + """Convert bool `forward` to string.""" + return "forward" if forward else "backward" + + @classmethod + def get_buffer_info(cls) -> str: + """ + Returns a key for `fp8_meta` that stores the module's index + in the global buffers along with autocast information. 
+ """ + return "buffer_index_and_autocast_key" + + @classmethod + def get_key_in_buffer( + cls, + forward: bool, + fp8_recipe: Recipe, + fp8_group: dist_group_type, + ) -> str: + """Returns a key into the global FP8 buffers.""" + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + fwd_bwd_key = cls.get_fwd_bwd_key(forward) + return f"{fwd_bwd_key}_{autocast_key}" + + @classmethod + def split_key_in_buffer(cls, key: str) -> Tuple[bool, str]: + """Splits buffer key into relevant parts.""" + forward, autocast_key = key.split("_", 1) + forward = forward == "forward" + return forward, autocast_key + + @classmethod + def add_fp8_tensors_to_global_buffer( + cls, + fp8_meta: Dict[str, Any], + ) -> None: + """ + Delayed scaling only. + + The amax reduction process happens completely outside the FP8 modules. + To participate in the reduction, the only role played by a module is + to call this function in order to append it's FP8 tensor into a global + buffer. There are 5 global buffers maintained, one each for amax, amax + history, scale, scale-inverse, and non-weight-mask. Each buffer has + keys that hold FP8 tensors. Keys have a `forward_` or `backward_` prefix + to indicate the type of FP8 tensor, since the forward and backward + reductions happen separately. + + Note: For CG capture, this method is called from the graphed + wrapper. For non CG case, it's called from within the module. + """ + + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + # Every module must call this function exactly once since + # the amax tensors are static. Ensures that compatibility + # with non-graphed modules is maintained. + index_in_buffer = cls.get_buffer_info() # Same index for fwd/bwd fp8 tensors. 
+ if index_in_buffer in fp8_meta: + return + + fp8_meta[index_in_buffer] = [] + for forward in (True, False): + fp8_meta_tensor_key = cls.get_meta_tensor_key(forward=forward) + if fp8_meta_tensor_key not in fp8_meta: + # Handles non-parameter FP8 modules, e.g. DPA. + continue + + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) + + if key not in cls.global_amax_buffer: + cls.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] + cls.global_amax_history_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history] + cls.global_scale_buffer[key] = [fp8_meta[fp8_meta_tensor_key].scale] + else: + cls.global_amax_buffer[key].append(fp8_meta[fp8_meta_tensor_key].amax_history[0]) + cls.global_amax_history_buffer[key].append( + fp8_meta[fp8_meta_tensor_key].amax_history + ) + cls.global_scale_buffer[key].append(fp8_meta[fp8_meta_tensor_key].scale) + fp8_meta[index_in_buffer].append(len(cls.global_amax_buffer[key]) - 1) + fp8_meta[index_in_buffer].append(key) + + @classmethod + def is_fp8_enabled(cls) -> bool: + """Is FP8 enabled""" + return cls.FP8_ENABLED + + @classmethod + def is_fp8_calibration(cls) -> bool: + """Is FP8 calibration""" + return cls.FP8_CALIBRATION + + @classmethod + def with_fp8_parameters(cls) -> bool: + """Should the parameters be stored as FP8""" + return cls.FP8_PARAMETERS + + @classmethod + def with_high_precision_init_val(cls) -> bool: + """Should the high precision initial values be stored with FP8 parameters""" + return cls.HIGH_PRECISION_INIT_VAL + + @classmethod + def fp8_graph_capturing(cls) -> bool: + """Is CUDA graph capture under way?""" + return cls.FP8_GRAPH_CAPTURING or torch.cuda.is_current_stream_capturing() + + @classmethod + def is_first_fp8_module(cls): + """Returns `True` only the first time when called multiple + times from within the same `autocast` context. 
+ """ + tmp = cls.IS_FIRST_FP8_MODULE + cls.IS_FIRST_FP8_MODULE = False + return tmp + + @classmethod + def get_fp8_recipe(cls) -> Recipe: + """Return the fp8 recipe""" + if cls.FP8_RECIPE is not None: + return cls.FP8_RECIPE + return get_default_fp8_recipe() + + @classmethod + def get_fp8_group(cls) -> Union[dist_group_type, None]: + """Return the fp8 group for scale/amax comm""" + return cls.FP8_DISTRIBUTED_GROUP + + @classmethod + def get_autocast_state(cls) -> Tuple[bool, bool, Recipe, dist_group_type, bool]: + """FP8 autocast state getter""" + return ( + cls.FP8_ENABLED, + cls.FP8_CALIBRATION, + cls.FP8_RECIPE, + cls.FP8_DISTRIBUTED_GROUP, + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING, + ) + + @classmethod + def set_autocast_state( + cls, fp8_state: Tuple[bool, bool, DelayedScaling, dist_group_type, bool] + ) -> None: + """FP8 autocast state setter""" + ( + cls.FP8_ENABLED, + cls.FP8_CALIBRATION, + cls.FP8_RECIPE, + cls.FP8_DISTRIBUTED_GROUP, + cls.IS_FIRST_FP8_MODULE, + cls.FP8_GRAPH_CAPTURING, + ) = fp8_state + + @staticmethod + def reduce_tensor_across_group_op_max(tensor: torch.Tensor, group: dist_group_type) -> None: + """Reduce tensor across given group.""" + if torch.distributed.is_initialized(): + torch.distributed.all_reduce( + tensor, + op=torch.distributed.ReduceOp.MAX, + group=group, + async_op=False, + ) + + @classmethod + def reduce_and_update_fp8_tensors( + cls, + forward: bool = True, + ) -> None: + """Delayed scaling only. Concatenate, reduce, and split amaxes in the global buffer.""" + # global_amax_buffer should only be non-empty for fp8 delayed scaling + for buffer_key, amax_buffer in cls.global_amax_buffer.items(): + # Check for forward or backward reduction. + fwd_update, autocast_key = cls.split_key_in_buffer(buffer_key) + if fwd_update != forward: + continue + if len(amax_buffer) == 0: + continue + + # Retrieve autocast specific args and concat amaxes. 
+ recipe, group = cls.autocast_arguments[autocast_key] + contiguous_amax = torch.cat(amax_buffer) + + # Reduction. + if ( + recipe.reduce_amax + and torch.distributed.is_initialized() + and torch.distributed.get_world_size(group=group) > 1 + ): + cls.reduce_tensor_across_group_op_max(contiguous_amax, group) + + # Amax and scale update. + unfused_update = ( + bool(int(os.getenv("NVTE_UNFUSED_FP8_UPDATE", "0"))) + or callable(recipe.amax_compute_algo) + or callable(recipe.scaling_factor_compute_algo) + ) + + if not unfused_update: + tex.fused_amax_and_scale_update_after_reduction( + contiguous_amax, + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + recipe.amax_compute_algo, + get_fp8_te_dtype(recipe, forward), + recipe.margin, + ) + else: + split_and_copy(contiguous_amax, amax_buffer, [x.numel() for x in amax_buffer]) + + for amax_history, scale in zip( + cls.global_amax_history_buffer[buffer_key], + cls.global_scale_buffer[buffer_key], + ): + _amax_and_scale_update( + amax_history, scale, get_fp8_max(recipe, forward), recipe + ) + + @classmethod + def get_unique_autocast_key( + cls, + recipe: Optional[Recipe] = None, + group: Optional[dist_group_type] = None, + ): + """ + For FP8, each autocast can be uniquely identified by the recipe and fp8 group. + Safely using `hash` as we never cross checkpoint boundaries. 
+ """ + return f"{str(recipe)}:{hash(group)}" + + @classmethod + def autocast_enter( + cls, + enabled: bool = False, + calibrating: bool = False, + fp8_recipe: Optional[Recipe] = None, + fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, + ) -> None: + """Set state and tracking variables for entry into FP8 region.""" + + fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + autocast_key = cls.get_unique_autocast_key(fp8_recipe, fp8_group) + cls.autocast_arguments[autocast_key] = (fp8_recipe, fp8_group) + + cls.FP8_ENABLED = enabled + cls.FP8_CALIBRATION = calibrating + cls.FP8_RECIPE = fp8_recipe + cls.FP8_DISTRIBUTED_GROUP = fp8_group + cls.FP8_GRAPH_CAPTURING = _graph + + if cls.AUTOCAST_DEPTH == 0: + cls.IS_FIRST_FP8_MODULE = True + cls.AUTOCAST_DEPTH += 1 + + if enabled: + fp8_available, reason_for_no_fp8 = cls.is_fp8_available() + assert fp8_available, reason_for_no_fp8 + if isinstance(fp8_recipe, MXFP8BlockScaling): + mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available() + assert mxfp8_available, reason_for_no_mxfp8 + if isinstance(fp8_recipe, Float8BlockScaling): + fp8_block_available, reason_for_no_fp8_block = cls.is_fp8_block_scaling_available() + assert fp8_block_available, reason_for_no_fp8_block + if isinstance(fp8_recipe, NVFP4BlockScaling): + nvfp4_available, reason_for_no_nvfp4 = cls.is_nvfp4_available() + assert nvfp4_available, reason_for_no_nvfp4 + + @classmethod + def autocast_exit(cls, enabled: bool, _graph: bool) -> None: + """Set state and tracking variables for exit from FP8 region.""" + cls.AUTOCAST_DEPTH -= 1 + # Reduce only the non-FP8 weight modules here. + # FP8 weight modules are reduced at the end of the optimizer + # step after the weight amax is populated. 
+        if enabled and cls.AUTOCAST_DEPTH == 0 and not _graph and torch.is_grad_enabled():
+            # delayed scaling only function; for other recipes (current scaling with any granularity)
+            # this is a noop because cls.global_amax_buffer is empty
+            cls.reduce_and_update_fp8_tensors(forward=True)
+
+    @classmethod
+    def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
+        """Copy the scaling factors and amaxes for recompute forward phase
+        to ensure both forward steps are numerically same.
+        """
+
+        # delayed scaling only function, noop for any other recipe
+        if not fp8_meta["recipe"].delayed():
+            return
+
+        buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
+
+        to_copy = [
+            fp8_meta["scaling_fwd"].amax_history.clone(),
+            fp8_meta["scaling_fwd"].scale.clone(),
+        ]
+
+        if buffer_position_key in fp8_meta:
+            cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].append(to_copy)
+        else:
+            if len(cls.fp8_tensors_recompute_buffer) == 0:
+                cls.fp8_tensors_recompute_buffer = [deque()]
+            else:
+                cls.fp8_tensors_recompute_buffer.append(deque())
+            cls.fp8_tensors_recompute_buffer[-1].append(to_copy)
+            fp8_meta[buffer_position_key] = len(cls.fp8_tensors_recompute_buffer) - 1
+
+    @classmethod
+    def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> None:
+        """Switch to the copied scaling factors and amaxes from phase
+        1 forward for identical numerical outputs.
+        """
+        # delayed scaling only function, noop for any other recipe
+        if not fp8_meta["recipe"].delayed():
+            return
+
+        # Store updated amaxes and scales from phase 1 post forward.
+        fp8_meta["updated_amax_history_fwd"] = fp8_meta["scaling_fwd"].amax_history.clone()
+        fp8_meta["updated_scale_fwd"] = fp8_meta["scaling_fwd"].scale.clone()
+
+        # Retrieve stashed amaxes and scales from phase 1 pre forward.
+ buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" + stashed_fp8_meta = cls.fp8_tensors_recompute_buffer[fp8_meta[buffer_position_key]].popleft() + + # Replace amaxes and scales with stashed values for phase 2 forward + fp8_meta["scaling_fwd"].amax_history.copy_(stashed_fp8_meta[0]) + fp8_meta["scaling_fwd"].scale.copy_(stashed_fp8_meta[1]) + + @staticmethod + def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: + """Restore latest scaling factors and amaxes after recompute forward run.""" + # delayed scaling only function, noop for any other recipe + if not fp8_meta["recipe"].delayed(): + return + + fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) + fp8_meta["scaling_fwd"].scale.copy_(fp8_meta["updated_scale_fwd"]) + + +@contextmanager +def fp8_model_init( + enabled: bool = True, + recipe: Optional[Recipe] = None, + preserve_high_precision_init_val: bool = False, +) -> None: + """ + .. warning:: + + fp8_model_init is deprecated and will be removed in a future release. Use + quantized_model_init(enabled=..., recipe=..., preserve_high_precision_init_val=...) instead. + + """ + + warnings.warn( + "fp8_model_init is deprecated and will be removed in a future release. " + "Use quantized_model_init(" + "enabled=..., recipe=..., preserve_high_precision_init_val=...) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + # Call new implementation. + with quantized_model_init( + enabled=enabled, + recipe=recipe, + preserve_high_precision_init_val=preserve_high_precision_init_val, + ): + yield + + +@contextmanager +def quantized_model_init( + enabled: bool = True, + recipe: Optional[Recipe] = None, + preserve_high_precision_init_val: bool = False, +) -> None: + """ + Context manager for initialization of quantized parameters. + + Example usage: + + .. 
code-block:: python + + with quantized_model_init(enabled=True): + model = transformer_engine.pytorch.Linear(768, 768) + + # Preserving high precision initial value to initialize master weight + with quantized_model_init(enabled=True, preserve_high_precision_init_val=True): + model = transformer_engine.pytorch.Linear(768, 768) + master_weight = model.weight.get_high_precision_init_val() + model.weight.clear_high_precision_init_val() + + Parameters + ---------- + enabled: bool, default = `True` + when enabled, Transformer Engine modules created inside this `quantized_model_init` + region will hold only quantized copies of its parameters, as opposed to the default + behavior where both higher precision and quantized copies are present. Setting this + option to `True` may result in lower memory consumption and is especially + useful for scenarios like: + + * full model training using optimizer with master weights, where the high + precision copies of weights are already present in the optimizer. + * inference, where only the quantized copies of the parameters are used. + * LoRA-like fine-tuning, where the main parameters of the model do not change. + recipe: transformer_engine.common.recipe.Recipe, default = `None` + Recipe used to create the parameters. If left to None, it uses the default recipe. + preserve_high_precision_init_val: bool, default = `False` + when enabled, store the high precision tensor used to initialize quantized parameters + in CPU memory, and add two function attributes named `get_high_precision_init_val()` + and `clear_high_precision_init_val()` to quantized parameters to get/clear this high + precision tensor. The purpose is that users can use this high-precision copy + to initialize master weights, avoiding the loss of precision that can occur when + using quantized parameters directly. Note that after the master weights are initialized, + users should call `clear_high_precision_init_val()` to release this CPU memory. 
+ + This functionality is *EXPERIMENTAL*. + """ + + _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS + _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE + _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL + FP8GlobalStateManager.FP8_PARAMETERS = enabled + FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val + try: + yield + finally: + FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val + + +@contextmanager +def fp8_autocast( + enabled: bool = True, + calibrating: bool = False, + fp8_recipe: Optional[Recipe] = None, + fp8_group: Optional[dist_group_type] = None, + _graph: bool = False, +) -> None: + """ + .. warning:: + + fp8_autocast is deprecated and will be removed in a future release. + Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead. + + """ + + warnings.warn( + "fp8_autocast is deprecated and will be removed in a future release. " + "Use autocast(enabled=..., calibrating=..., recipe=..., group=..., _graph=...) instead.", + category=DeprecationWarning, + stacklevel=2, + ) + + # Call new implementation. + with autocast( + enabled=enabled, + calibrating=calibrating, + recipe=fp8_recipe, + amax_reduction_group=fp8_group, + _graph=_graph, + ): + yield + + +@contextmanager +def autocast( + enabled: bool = True, + calibrating: bool = False, + recipe: Optional["Recipe"] = None, + amax_reduction_group: Optional["dist_group_type"] = None, + _graph: bool = False, +) -> None: + """ + Context manager for quantization schemes like FP8 or FP4. + + .. code-block:: python + + with autocast(enabled=True): + out = model(inp) + + .. 
note:: + + Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors + with shapes where both dimensions are divisible by 16. In terms of the input to the full + Transformer network, this typically requires padding sequence length to be multiple of 16. + + .. note:: + + When :attr:`recipe.reduce_amax==True`, any module must not be invoked more than once + inside a single `autocast` region. This is unsupported behavior because the amax + reduction is handled during the exit of the `autocast` context. Calling the same + module more than once inside an `autocast` region overrides the amax tensors + before reduction can occur. + + Parameters + ---------- + enabled: bool, default = `True` + whether or not to enable low precision quantization (FP8/FP4). + calibrating: bool, default = `False` + calibration mode allows collecting statistics such as amax and scale + data of quantized tensors even when executing without quantization enabled. + This is useful for saving an inference ready checkpoint while training + using a higher precision. + recipe: recipe.Recipe, default = `None` + recipe used for low precision quantization. + amax_reduction_group: torch._C._distributed_c10d.ProcessGroup, default = `None` + distributed group over which amaxes for the quantized tensors + are reduced at the end of each training step. + """ + + if enabled: + check_recipe_support(recipe) + + # Save current state so we always restore it on exit. 
+ fp8_state = FP8GlobalStateManager.get_autocast_state() + + FP8GlobalStateManager.autocast_enter( + enabled=enabled, + calibrating=calibrating, + fp8_recipe=recipe, + fp8_group=amax_reduction_group, + _graph=_graph, + ) + try: + yield + finally: + FP8GlobalStateManager.set_autocast_state(fp8_state) + FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + + +def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: + """Update amax history and set next amax to zero.""" + if amax_history.shape[0] > 1: + new_amax_history = torch.roll(amax_history, -1, 0) + amax_history.copy_(new_amax_history) + amax_history[0].fill_(0.0) + return amax_history + + +@torch.jit.script +def _default_get_amax_and_update_history( + amax_history: torch.Tensor, + amax_compute_algo: str, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Default function to obtain amax from history.""" + if amax_compute_algo == "max": + amax = torch.max(amax_history, dim=0).values + else: # amax_compute_algo == "most_recent" + amax = amax_history[0].clone() + + amax_history = _update_amax_history(amax_history) + return amax_history, amax + + +@jit_fuser +def _default_sf_compute( + amax: torch.Tensor, + scale: torch.Tensor, + fp8_max: float, + margin: int, + _fp32_max: float = torch.finfo(torch.float32).max, # finfo not available in jitter +) -> torch.Tensor: + """Default function to convert amax to scaling factor. + Computing the scaling factor requires consideration of the following scenarios: + 1. amax == 0: + No action is possible, set scale to the previous scale (or 1). + 2. 0 < amax < tiny_amax + The amax is too tiny that the scale becomes infinite in FP32. + Set scale = FP32_max + 3. tiny_amax <= amax < FP32_max: + Set scale = FP8_max (or scaled_max) / amax + 4. When amax == inf or amax == nan: + No action is possible, set scale to the previous scale (or 1). 
+ """ + sf = (fp8_max / amax) / (2**margin) + sf = torch.where(amax > 0.0, sf, scale) + sf = torch.where(torch.isfinite(amax), sf, scale) + sf = torch.where(torch.isinf(sf), torch.full_like(sf, _fp32_max), sf) + scale.copy_(sf) + return scale + + +def _compute_amax_and_update_history( + amax_history: torch.Tensor, + amax_compute_algo: Union[Callable, str], +) -> Tuple[torch.Tensor, torch.Tensor]: + """Obtain the amax from the history.""" + + if callable(amax_compute_algo): + amax = amax_compute_algo(amax_history) + amax_history = _update_amax_history(amax_history) + return amax_history, amax + return _default_get_amax_and_update_history( + amax_history, + amax_compute_algo, + ) + + +def _compute_scaling_factor( + amax: torch.Tensor, + scale: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, +) -> torch.Tensor: + """Convert amax to scaling factor.""" + + if recipe.scaling_factor_compute_algo is None: + return _default_sf_compute( + amax, + scale, + fp8_max, + recipe.margin, + ) + return recipe.scaling_factor_compute_algo(amax, scale, fp8_max, recipe) + + +def _amax_and_scale_update( + amax_history: torch.Tensor, + scale: torch.Tensor, + fp8_max: float, + recipe: DelayedScaling, +) -> None: + """Updates FP8 meta tensors.""" + new_amax_history, amax = _compute_amax_and_update_history( + amax_history, + recipe.amax_compute_algo, + ) + new_scale = _compute_scaling_factor(amax, scale, fp8_max, recipe) + scale.copy_(new_scale) + amax_history.copy_(new_amax_history) + + +def split_and_copy( + buffer: torch.Tensor, + outputs: List[torch.Tensor], + chunk_sizes: List[int], +) -> None: + """Split `buffer` by `chunk_sizes` and copy into `outputs`.""" + splits = buffer.split(chunk_sizes) + torch._foreach_copy_(outputs, splits) + + +class RecipeState(abc.ABC): + """Configuration and state for a quantization recipe. + + This is a builder class for quantizers, which are in turn builder + classes for quantized tensors. 
+ + This class may pack together the state for multiple quantizers, + which is helpful for applying fused kernels with less overhead. + + """ + + @staticmethod + def create( + recipe: Recipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> RecipeState: + """Factory method to create the state for a quantization recipe + + Parameters + ---------- + recipe: Recipe + Quantization recipe. + mode: {"forward", "backward"} + Training stage where quantization will be performed. + num_quantizers: int, default = 1 + Number of quantizers to create state for. + device: torch.device, default = default CUDA device + Device for quantized tensors. + + Returns + ------- + RecipeState: + Quantization recipe state. + + """ + + cls = None + if recipe.delayed(): + cls = DelayedScalingRecipeState + elif recipe.mxfp8(): + cls = MXFP8BlockScalingRecipeState + elif recipe.float8_current_scaling(): + cls = Float8CurrentScalingRecipeState + elif recipe.float8_block_scaling(): + cls = Float8BlockScalingRecipeState + elif recipe.nvfp4(): + cls = NVFP4BlockScalingRecipeState + elif recipe.custom(): + cls = CustomRecipeState + else: + raise ValueError(f"{recipe.__class__.__name__} is not supported") + return cls( + recipe, + mode=mode, + num_quantizers=num_quantizers, + device=device, + ) + + @abc.abstractmethod + def make_quantizers(self) -> list: + """Convert recipe state to quantizers. + + Quantizers are builder classes for quantized tensors. They are + typically used to convert a high-precision tensor (e.g. in + FP32 or BF16) into a quantized tensor (e.g. in FP8). + + """ + + +class DelayedScalingRecipeState(RecipeState): + """State for FP8 quantization with per-tensor delayed scaling. + + Delayed scaling recipe requires a scaling factor (applied when + casting to FP8) and a history of max-abs values ("amax") from + recent FP8 casts for updating the scaling factor. The scale update + is handled externally by `FP8GlobalStateManager`. 
+ + """ + + recipe: DelayedScaling + mode: str + dtype: tex.DType + scale: torch.Tensor + amax_history: torch.Tensor + + def __init__( + self, + recipe: DelayedScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.scale = torch.ones(num_quantizers, dtype=torch.float32, device=device) + self.amax_history = torch.zeros( + recipe.amax_history_len, + num_quantizers, + dtype=torch.float32, + device=device, + ) + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_tensor import Float8Quantizer + + return [ + Float8Quantizer(self.scale[i], self.amax_history[0][i].reshape((1,)), self.dtype) + for i in range(self.num_quantizers) + ] + + +class Float8CurrentScalingRecipeState(RecipeState): + """Configuration for Per-tensor current scaling quantization. + + Per-tensor current quantization does not require state. 
+ + """ + + recipe: Float8CurrentScaling + mode: str + dtype: tex.DType + device: torch.device + + def __init__( + self, + recipe: Float8CurrentScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + from .tensor.float8_tensor import Float8CurrentScalingQuantizer + + return [ + Float8CurrentScalingQuantizer( + self.dtype, device=self.device, force_pow_2_scales=self.recipe.use_power_2_scales + ) + for i in range(self.num_quantizers) + ] + + +class MXFP8BlockScalingRecipeState(RecipeState): + """Configuration for MXFP8 quantization. + + MXFP8 quantization does not require state. + + """ + + recipe: MXFP8BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: MXFP8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp8_te_dtype(recipe, mode == "forward") + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.mxfp8_tensor import MXFP8Quantizer + + return [MXFP8Quantizer(self.dtype) for i in range(self.num_quantizers)] + + +class Float8BlockScalingRecipeState(RecipeState): + """Configuration for Float8BlockScaling quantization. + + Float8BlockScaling quantization does not require state, + but different quantizers use different modes. 
+ """ + + recipe: Float8BlockScaling + mode: str + qx_dtype: tex.DType + qw_dtype: tex.DType + qgrad_dtype: tex.DType + + def __init__( + self, + recipe: Float8BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.qx_dtype = get_fp8_te_dtype(recipe, True) + self.qw_dtype = get_fp8_te_dtype(recipe, True) + self.qgrad_dtype = get_fp8_te_dtype(recipe, False) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + self.device = device + + def make_quantizers(self) -> list: + # TODO(ksivamani); Find better design for this, adding here to avoid circular import. + from .tensor.float8_blockwise_tensor import Float8BlockQuantizer + + if self.mode == "forward": + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward, and doesn't play nicely with QuantizeOp, + # which is not associated with a GEMM. + assert self.num_quantizers % 3 == 0 # x, w, output per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qw_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, + block_scaling_dim=self.recipe.w_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qx_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, + block_scaling_dim=self.recipe.x_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 3) + ] + ) + ) + + assert 
self.mode == "backward", f"Unexpected mode {self.mode}" + assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm + return list( + itertools.chain.from_iterable( + [ + [ + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + Float8BlockQuantizer( + fp8_dtype=self.qgrad_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, + force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, + block_scaling_dim=self.recipe.grad_block_scaling_dim, + ), + ] + for _ in range(self.num_quantizers // 2) + ] + ) + ) + + +class NVFP4BlockScalingRecipeState(RecipeState): + """Configuration for NVFP4 quantization. + + NVFP4 quantization does not require state. + + """ + + recipe: NVFP4BlockScaling + mode: str + dtype: tex.DType + + def __init__( + self, + recipe: NVFP4BlockScaling, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + self.dtype = get_fp4_te_dtype(recipe) + + # Allocate buffers + if device is None: + device = torch.device("cuda") + + def make_quantizers(self) -> list: + from .tensor.nvfp4_tensor import NVFP4Quantizer + + # The index convention (coming from base.py set_meta_tensor) + # is somewhat awkward. It assumes forward quantizers are + # ordered [input, weight, output, ...] and backward quantizers + # are ordered [grad_output, grad_input, ...]. This doesn't + # play nicely with fusible ops: Linear op doesn't own output + # or grad input quantizers, Quantize op only owns input and + # grad output quantizers. 
+ + if self.mode == "forward": + + def _make_quantizer(idx: int) -> NVFP4Quantizer: + qparams = ( + self.recipe.fp4_quant_fwd_weight + if idx % 3 == 1 + else self.recipe.fp4_quant_fwd_inp + ) + return NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + ) + + return [_make_quantizer(idx) for idx in range(self.num_quantizers)] + + if self.mode == "backward": + return [ + NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, + with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, + stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + ) + for _ in range(self.num_quantizers) + ] + + raise RuntimeError(f"Unexpected recipe mode ({self.mode})") + + +class CustomRecipeState(RecipeState): + """State for CustomRecipe: produce quantizers per tensor.""" + + recipe: CustomRecipe + mode: str + num_quantizers: int + device: Optional[torch.device] + + def __init__( + self, + recipe: CustomRecipe, + *, + mode: str, + num_quantizers: int = 1, + device: Optional[torch.device] = None, + ) -> None: + self.recipe = recipe + self.mode = mode + self.num_quantizers = num_quantizers + if device is None: + device = torch.device("cuda") + self.device = device + + if getattr(recipe, "qfactory", None) is None: + raise ValueError("CustomRecipe requires `qfactory`.") + + def make_quantizers(self) -> list: + qfactory = self.recipe.qfactory + out = [] + + # TODO(negvet): make_quantizers() should take roles from the operation + # Hardcode linear-specific roles for now + roles: List[str] + if self.mode == "forward": + roles = [ + ("linear_input", "linear_weight", 
"linear_output")[i % 3] + for i in range(self.num_quantizers) + ] + elif self.mode == "backward": + roles = [ + ("linear_grad_output", "linear_grad_input")[i % 2] + for i in range(self.num_quantizers) + ] + else: + roles = ["unknown"] * self.num_quantizers + + for i in range(self.num_quantizers): + # Get quantizer from the user defined factory + quantizer = qfactory(roles[i]) + out.append(quantizer) + return out diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py similarity index 69% rename from transformer_engine/pytorch/tensor/quantized_tensor.py rename to transformer_engine/pytorch/quantized_tensor.py index 2f634f399..dd6a7ebc5 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -4,7 +4,7 @@ # # See LICENSE for license information. -"""Tensor with quantized data""" +"""Pure Python base classes for quantization.""" from __future__ import annotations from torch.utils.cpp_extension import IS_HIP_EXTENSION @@ -13,16 +13,26 @@ import abc import copy import warnings +import math import torch from torch.utils._pytree import tree_map -import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor._quantization_helpers import ( + _QuantizeFunc, + _IdentityFunc, + _stride_from_shape, +) +_quantized_tensor_cpu_supported_ops = ( + torch.ops.aten.empty_like.default, + torch.ops.aten.copy_.default, +) -class QuantizedTensorBase: - r"""Base class for all *TensorBase classes. + +class QuantizedTensorStorage: + r"""Base class for all *TensorStorage classes. This class (and its subclasses) are optimization for when the full QuantizedTensor is not needed (when it is fully @@ -30,12 +40,12 @@ class QuantizedTensorBase: PyTorch's autograd). 
When creating a new tensor type X one should create both - XTensorBase class inheriting from QuantizedTensorBase and - XTensor inheriting from XTensorBase and QuantizedTensor. - XTensorBase should contain all data members needed to + XTensorStorage class inheriting from QuantizedTensorStorage and + XTensor inheriting from XTensorStorage and QuantizedTensor. + XTensorStorage should contain all data members needed to implement the functionality of the tensor, while XTensor should only implement the functionality needed - to behave like regular torch.Tensor (liek __torch_dispatch__).""" + to behave like regular torch.Tensor (like __torch_dispatch__).""" _quantizer: Optional[Quantizer] @@ -63,7 +73,13 @@ def update_usage( f"{self.__class__.__name__} class does not implement update_usage function" ) - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement get_usages function" + ) + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement prepare_for_saving function" @@ -77,6 +93,30 @@ def restore_from_saved( f"{self.__class__.__name__} class does not implement restore_from_saved function" ) + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. + + """ + if self._quantizer is not None: + return self._quantizer + return self._build_default_quantizer() + + def _build_default_quantizer(self) -> Quantizer: + """Build default quantizer for the tensor""" + raise ValueError( + f"{self.__class__.__name__} has no quantizer " + "and no default quantizer is available defined in the subclass." 
+ ) + + def quantize_( + self, tensor: torch.Tensor, *, noop_flag: Optional[torch.Tensor] = None + ) -> QuantizedTensor: + """Quantize tensor in-place""" + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + def update_quantizer(self, quantizer: Quantizer): """Update quantizer for the tensor""" if self._quantizer is None: @@ -87,13 +127,13 @@ def update_quantizer(self, quantizer: Quantizer): def prepare_for_saving( - *tensors: Union[torch.Tensor, QuantizedTensorBase], + *tensors: Union[torch.Tensor, QuantizedTensorStorage], ) -> Tuple[ - list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorBase]] + list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]] ]: """Prepare tensors for saving. Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save - the internal TensorBase types too.""" + the internal *TensorStorage types too.""" tensor_list, tensor_objects_list = [], [] for tensor in tensors: @@ -104,16 +144,17 @@ def prepare_for_saving( t, t_obj = tensor.prepare_for_saving() tensor_list.extend(t) tensor_objects_list.append(t_obj) + return tensor_list, tensor_objects_list def restore_from_saved( - tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]], + tensors: list[Optional[Union[torch.Tensor, QuantizedTensorStorage]]], saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], return_saved_tensors: bool = False, ) -> ( - list[Optional[torch.Tensor | QuantizedTensorBase]] - | tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]] + list[Optional[torch.Tensor | QuantizedTensorStorage]] + | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]] ): """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] @@ -182,7 +223,6 @@ def __repr__(self): ")" ) - @abc.abstractmethod def 
update_quantized( self, src: torch.Tensor, @@ -191,6 +231,9 @@ def update_quantized( noop_flag: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Quantize tensor in-place""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement update_quantized" + ) def quantize( self, @@ -203,8 +246,14 @@ def quantize( if out is not None: return self.update_quantized(tensor, out) if (not self.internal) and torch.is_grad_enabled(): - return _QuantizeFunc.apply(tensor, self) - return _QuantizeFunc.forward(None, tensor, self) + return _QuantizeFunc.apply(tensor, self.quantize_impl) + return _QuantizeFunc.forward(None, tensor, self.quantize_impl) + + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement quantize_impl function" + ) def multi_quantize(self, list_of_tensors): """Quantize multiple tensors""" @@ -217,7 +266,6 @@ def __call__(self, tensor: torch.Tensor) -> QuantizedTensor: """Quantize tensor""" return self.quantize(tensor) - @abc.abstractmethod def make_empty( self, shape: Iterable[int], @@ -226,8 +274,11 @@ def make_empty( device: Optional[torch.device] = None, ) -> QuantizedTensor: """Construct quantized tensor with uninitialized data""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement make_empty function, " + "required for construction of unintialized quantized tensor" + ) - @abc.abstractmethod def calibrate(self, tensor: torch.Tensor) -> None: """Calibrate quantizer state @@ -256,90 +307,36 @@ def copy(self) -> Quantizer: def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Symbolic function for ONNX export""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement onnx_quantize" + ) def onnx_dequantize(self, tensor) -> torch.Tensor: """Symbolic function for ONNX export""" + raise NotImplementedError( + 
f"{self.__class__.__name__} class does not implement onnx_dequantize" + ) - @abc.abstractmethod def _get_compatible_recipe(self) -> Union[type[Recipe], None]: """Returns recipe class that is compatible with this quantizer""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement _get_compatible_recipe" + ) def supports_only_rowwise_all_gather(self) -> bool: """Returns True if the quantizer supports only rowwise all-gather""" return False + def is_quantizable(self, inp: torch.Tensor) -> bool: # pylint: disable=unused-argument + """Returns whether or not given tensor can be quantized""" + return True -class _QuantizeFunc(torch.autograd.Function): - """Cast to FP8 from other dtype""" - - @staticmethod - def forward( - _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: torch.Tensor, - quantizer: Quantizer, - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton - use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) - quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize - return quantize_func(tensor, quantizer) - else: - return tex.quantize(tensor, quantizer) - - @staticmethod - def backward( - _ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor # unused - ) -> Tuple[Optional[torch.Tensor], ...]: - # pylint: disable=missing-function-docstring - # Assume that we want gradients in full precision - return grad, None - - -class _IdentityFunc(torch.autograd.Function): - """Identity function - - If constructor keyword-arguments are provided, then construct a - new Float8Tensor using the provided tensor's attributes. 
- - """ - - @staticmethod - def forward( - ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None - ) -> QuantizedTensor: - # pylint: disable=missing-function-docstring - - # Return input tensor if constructor kwargs are not provided - if init_kwargs is None: - return tensor.detach() - - # Construct new tensor if constructor kwargs are provided - ctx.input_dtype = tensor.dtype - kwargs = tensor.get_metadata() - for key, val in init_kwargs.items(): - kwargs[key] = val - return type(tensor)(tensor.shape, tensor.dtype, **kwargs) - - @staticmethod - def backward(ctx, grad_output): - # pylint: disable=missing-function-docstring - grad_input = grad_output - if grad_input.dtype == ctx.input_dtype: - grad_input = grad_input.detach() - else: - grad_input = grad_input.to(ctx.input_dtype) - return grad_input, None - - -def _stride_from_shape(shape: list[int]): - if len(shape) == 0: - return [] - rstride = [1] - for d in reversed(shape[1:]): - rstride.append(rstride[-1] * d) - return list(reversed(rstride)) - + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the quantizer""" + return { + "rowwise": self.rowwise_usage, + "columnwise": self.columnwise_usage, + } class QuantizedTensor(torch.Tensor): """Abstract base class for tensor with quantized data @@ -351,7 +348,14 @@ class QuantizedTensor(torch.Tensor): """ - def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: bool = False): + def __new__( + cls, + shape: Iterable[int], + dtype: torch.dtype, + *, + requires_grad: bool = False, + device: Optional[torch.device] = None, + ): # We are assuming only contiguous tensors stride = _stride_from_shape(shape) instance = torch.Tensor._make_wrapper_subclass( @@ -362,7 +366,7 @@ def __new__(cls, shape: Iterable[int], dtype: torch.dtype, *, requires_grad: boo dtype=dtype, layout=torch.strided, requires_grad=requires_grad, - device=torch.cuda.current_device(), + device=torch.cuda.current_device() if device is None else device, ) 
return instance @@ -392,6 +396,9 @@ def detach(self) -> QuantizedTensor: def clear(self): """Deallocate this tensor's memory. Typically not needed and must be used carefully""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement clear function" + ) def __repr__(self, *, tensor_contents=None) -> str: return f"{self.__class__.__name__}(data={self.dequantize(dtype=self.dtype)})" @@ -433,6 +440,26 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.copy_.default: dst = args[0] src = args[1] + if ( + isinstance(dst, QuantizedTensor) + and isinstance(src, QuantizedTensor) + and type(dst._quantizer) is type(src._quantizer) + and set(src.get_usages().keys()) == set(dst.get_usages().keys()) + and all( + src.get_usages()[usage] == dst.get_usages()[usage] + for usage in src.get_usages().keys() + ) + ): + + dst_tensors, dst_tensor_obj = dst.prepare_for_saving() + src_tensors, src_tensor_obj = src.prepare_for_saving() + for dst_tensor, src_tensor in zip(dst_tensors, src_tensors): + if dst_tensor is not None: + dst_tensor.copy_(src_tensor, *args[2:], **kwargs) + dst_tensor_obj.restore_from_saved(dst_tensors) + src_tensor_obj.restore_from_saved(src_tensors) + return None + if isinstance(dst, QuantizedTensor): dst.quantize_(src) else: @@ -445,6 +472,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == torch.ops.aten.view.default: raise NotImplementedError("{cls.__name__} class does not support tensor views") + # Empty like op + if func == torch.ops.aten.empty_like.default: + tensor = args[0] + device = kwargs.get("device", tensor.device) + requires_grad = kwargs.get("requires_grad", tensor.requires_grad) + pin_memory = kwargs.get("pin_memory", False) + usage = tensor.get_usages() + quantizer_usage = tensor._quantizer.get_usages() + tensor._quantizer.set_usage(**usage) + out = tensor._quantizer.make_empty( + shape=tensor.shape, + dtype=tensor.dtype, + device=device, + 
requires_grad=requires_grad, + pin_memory=pin_memory, + ) + tensor._quantizer.set_usage(**quantizer_usage) + return out + + if func == torch.ops.aten.numel.default: + tensor = args[0] + return math.prod(tensor.size()) + + if func == torch.ops.aten.is_pinned.default: + tensor = args[0] + for t in tensor.get_data_tensors(): + if t is not None: + return func(t) + return False # Or error out? + def maybe_unwrap(arg): if isinstance(arg, QuantizedTensor): return arg.dequantize(dtype=arg.dtype) @@ -459,6 +516,10 @@ def maybe_update_inplace(arg, new_arg, schema_arg): and schema_arg.alias_info.is_write ): arg.quantize_(new_arg) + elif isinstance(arg, list) and isinstance(new_arg, list): + # Recursively handle update for lists of tensors + for a, na in zip(arg, new_arg): + maybe_update_inplace(a, na, schema_arg) # In-place op: dequantize, perform op, and quantize if func._schema.is_mutable: @@ -485,6 +546,16 @@ def maybe_update_inplace(arg, new_arg, schema_arg): def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} + + def check_if_cpu(arg): + if isinstance(cls, QuantizedTensor) and arg.device.type == "cpu": + assert ( + func in _quantized_tensor_cpu_supported_ops + ), f"QuantizedTensor on CPU does not support this operation: {func}" + return arg + + args = tree_map(check_if_cpu, args) + # Do not force the QuantizedTensor type on the returned tensor return torch._C._disabled_torch_function_impl(func, types, args, kwargs) @@ -515,20 +586,16 @@ def make_like( shape: Optional[Iterable[int]] = None, dtype: Optional[torch.dtype] = None, requires_grad: bool = False, - data: Optional[torch.Tensor] = None, ) -> QuantizedTensor: """Create new quantized tensor By default, new tensor has the same attributes and underlying - data. + data. This function is intended to create view of tensors. 
""" - if shape is None: - shape = data.shape if data is not None else tensor.shape + shape = shape if shape is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() - if data is not None: - kwargs["data"] = data return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs) def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor: diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index e86873b12..12a87d4bd 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -47,8 +47,8 @@ from build_tools.build_ext import get_build_ext -from build_tools.utils import ( - rocm_build, copy_common_headers, copy_hipify_tools, clear_hipify_tools_copy ) +from build_tools.utils import rocm_build, copy_hipify_tools, clear_hipify_tools_copy +from build_tools.utils import copy_common_headers, min_python_version_str from build_tools.te_version import te_version from build_tools.pytorch import ( setup_pytorch_extension, @@ -150,14 +150,29 @@ def run(self): ) ] + # Setup version and requirements. + # Having the framework extension depend on the core lib allows + # us to detect CUDA version dynamically during compilation and + # choose the correct wheel for te core lib. + __version__ = te_version() + if not rocm_build(): + cuda_major_version = parse(torch.version.cuda).major + assert cuda_major_version in (12, 13), f"Unsupported cuda version {torch.version.cuda}." 
+ te_core = f"transformer_engine_cu{cuda_major_version}=={__version__}" + install_requires = install_requirements() + [te_core] + else: + te_core = f"transformer_engine_rocm=={__version__}" + install_requires = install_requirements() + [te_core] + # Configure package setuptools.setup( name=PACKAGE_NAME, - version=te_version(), + version=__version__, description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, - install_requires=install_requirements(), + python_requires=f">={min_python_version_str()}", + install_requires=install_requires, tests_require=test_requirements(), ) if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")): diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py index 7fa12cc08..ada624a90 100644 --- a/transformer_engine/pytorch/tensor/__init__.py +++ b/transformer_engine/pytorch/tensor/__init__.py @@ -6,12 +6,42 @@ import torch -from .quantized_tensor import QuantizedTensor, Quantizer +from ..quantized_tensor import ( + QuantizedTensorStorage, + QuantizedTensor, + Quantizer, + prepare_for_saving, + restore_from_saved, +) +from .storage.float8_tensor_storage import Float8TensorStorage +from .storage.mxfp8_tensor_storage import MXFP8TensorStorage +from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .storage.nvfp4_tensor_storage import NVFP4TensorStorage +from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer +from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer +from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer +from .nvfp4_tensor import NVFP4Tensor, NVFP4Quantizer from .utils import cast_master_weights_to_fp8, replace_raw_data __all__ = [ - "QuantizedTensor", "Quantizer", + "Float8Quantizer", + "Float8CurrentScalingQuantizer", + "MXFP8Quantizer", + "Float8BlockQuantizer", + 
"NVFP4Quantizer", + "QuantizedTensorStorage", + "Float8TensorStorage", + "MXFP8TensorStorage", + "Float8BlockwiseQTensorStorage", + "NVFP4TensorStorage", + "QuantizedTensor", + "Float8Tensor", + "MXFP8Tensor", + "Float8BlockwiseQTensor", + "NVFP4Tensor", + "prepare_for_saving", + "restore_from_saved", ] @@ -48,21 +78,16 @@ def get_all_tensor_types(): """ Get all tensor-like types that can be used in TE. """ - from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8TensorBase - from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8TensorBase - from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( - Float8BlockwiseQTensor, - Float8BlockwiseQTensorBase, - ) - all_tensor_types = [ torch.Tensor, torch.nn.Parameter, Float8Tensor, - Float8TensorBase, + Float8TensorStorage, MXFP8Tensor, - MXFP8TensorBase, + MXFP8TensorStorage, Float8BlockwiseQTensor, - Float8BlockwiseQTensorBase, + Float8BlockwiseQTensorStorage, + NVFP4Tensor, + NVFP4TensorStorage, ] return all_tensor_types diff --git a/transformer_engine/pytorch/tensor/_quantization_helpers.py b/transformer_engine/pytorch/tensor/_quantization_helpers.py new file mode 100644 index 000000000..55fc4785d --- /dev/null +++ b/transformer_engine/pytorch/tensor/_quantization_helpers.py @@ -0,0 +1,85 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Private helper functions and classes for quantized tensor implementations. + +This module contains internal autograd functions and utilities that support +the quantization machinery. 
+""" + +from __future__ import annotations +from typing import Callable, Optional, Tuple, Any, Dict, TYPE_CHECKING +import torch + +if TYPE_CHECKING: + from transformer_engine.pytorch.quantized_tensor import QuantizedTensor + + +class _QuantizeFunc(torch.autograd.Function): + """Quantize tensor""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: torch.Tensor, + quantize_impl: Callable, + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + # TODO: bring back triton based quantization + return quantize_impl(tensor) + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class _IdentityFunc(torch.autograd.Function): + """Identity function + + If constructor keyword-arguments are provided, then construct a + new Float8Tensor using the provided tensor's attributes. 
+ + """ + + @staticmethod + def forward( + ctx, tensor: QuantizedTensor, init_kwargs: Optional[Dict[str, Any]] = None + ) -> QuantizedTensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if constructor kwargs are not provided + if init_kwargs is None: + return tensor.detach() + + # Construct new tensor if constructor kwargs are provided + ctx.input_dtype = tensor.dtype + kwargs = tensor.get_metadata() + for key, val in init_kwargs.items(): + kwargs[key] = val + return type(tensor)(tensor.shape, tensor.dtype, **kwargs) + + @staticmethod + def backward(ctx, grad_output): + # pylint: disable=missing-function-docstring + grad_input = grad_output + if grad_input.dtype == ctx.input_dtype: + grad_input = grad_input.detach() + else: + grad_input = grad_input.to(ctx.input_dtype) + return grad_input, None + + +def _stride_from_shape(shape: list[int]): + """Calculate stride from shape for contiguous tensors""" + if len(shape) == 0: + return [] + rstride = [1] + for d in reversed(shape[1:]): + rstride.append(rstride[-1] * d) + return list(reversed(rstride)) diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index 0e41fc9c5..8440c14b7 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -13,8 +13,9 @@ from transformer_engine_torch import Float8BlockScaleTensorFormat from transformer_engine.common.recipe import Float8BlockScaling, Recipe -from ._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc from ..utils import devices_match, round_up_to_nearest_multiple aten = torch.ops.aten @@ -101,6 
+102,10 @@ def update_quantized( dst._fp8_dtype = self.dtype return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + return tex.quantize(tensor, self) + def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]: """Calculate the shape of the scaling tensor for blockwise quantization. @@ -209,6 +214,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8BlockwiseQTensor: """Construct quantized tensor with uninitialized data""" if device is None: @@ -224,12 +230,13 @@ def make_empty( data = None scale_inv = None if self.rowwise_usage: - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) scale_shape = self.get_scale_shape(shape, columnwise=False) scale_inv = torch.empty( scale_shape, dtype=torch.float32, device=device, + pin_memory=pin_memory, ) # Allocate FP8 data transpose if needed @@ -237,13 +244,17 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty( - self.get_columnwise_shape(shape), dtype=torch.uint8, device=device + self.get_columnwise_shape(shape), + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( columnwise_scale_shape, dtype=torch.float32, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -270,7 +281,7 @@ def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8BlockScaling -class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): +class Float8BlockwiseQTensor(Float8BlockwiseQTensorStorage, QuantizedTensor): """Tensor class with FP8 data quantized via NxN blocks or 1xN blocks. 
The tensor presents as having a standard, higher-precision dtype, @@ -295,7 +306,7 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor): holds configuration about quantization and dequantization modes. """ - # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a Float8BlockwiseQTensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, @@ -334,15 +345,6 @@ def __repr__(self, *, tensor_contents=None): f" data_format={self._data_format}" ) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. - - """ - assert self._quantizer is not None - return self._quantizer - def quantize_( self, tensor: torch.Tensor, @@ -361,8 +363,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize()) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """ @@ -405,6 +406,21 @@ def reshape(self, *shape: Tuple[int]) -> Float8BlockwiseQTensor: # pylint: disable=missing-function-docstring return _ReshapeFunc.apply(self, shape) + def untyped_storage(self) -> torch.UntypedStorage: + """Return the underlying UntypedStorage of the FP8 data. + + Note that FP8 block-scaled tensor may involve multiple + buffers: row-wise FP8 data, row-wise scales, column-wise FP8 + data, column-wise scales. The UntypedStorage of the row-wise + FP8 data is returned if it exists, and otherwise the + UntypedStorage of the column-wise FP8 data. 
+ + """ + data = self._rowwise_data if self._rowwise_data is not None else self._columnwise_data + if data is not None: + return data.untyped_storage() + return torch.UntypedStorage(0, device=self.device) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): @@ -429,6 +445,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) return Float8BlockwiseQTensor.make_like(tensor) + # record stream op + if func == torch.ops.aten.record_stream.default: + qt, stream = args + for t in ( + qt._rowwise_data, + qt._columnwise_data, + qt._rowwise_scale_inv, + qt._columnwise_scale_inv, + ): + if t is not None and t.is_cuda: + t.record_stream(stream) + return None + # Default case return super().__torch_dispatch__(func, types, args, kwargs) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index a0a17d1a1..316733e31 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -1,27 +1,30 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
"""Tensor class with FP8 data""" from __future__ import annotations -import os -from typing import Optional, Tuple, Iterable, Union -import warnings -from torch.utils.cpp_extension import IS_HIP_EXTENSION +from typing import Any, Optional, Tuple, Iterable, Union +import warnings import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe from ..utils import canonicalize_process_group, devices_match -from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc from ..constants import dist_group_type + +from torch.utils.cpp_extension import IS_HIP_EXTENSION if IS_HIP_EXTENSION: from ..triton_kernels.cast import te_quantize_triton + import os aten = torch.ops.aten @@ -100,6 +103,16 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) + def make_empty( self, shape: Iterable[int], @@ -107,6 +120,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8Tensor: # Canonicalize tensor attributes @@ -114,16 +128,19 @@ def make_empty( device = torch.device("cuda") # 
Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - transpose_shape = [data.size(-1)] + list(data.shape[:-1]) + transpose_shape = [shape[-1]] + list(shape[:-1]) data_transpose = torch.empty( transpose_shape, dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -131,7 +148,7 @@ def make_empty( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -158,7 +175,7 @@ def create_tensor_from_data( torch.float8_e5m2fnuz, ] if internal: - return Float8TensorBase( + return Float8TensorStorage( data=data, fp8_scale_inv=1 / self.scale, fp8_dtype=self.dtype, @@ -226,6 +243,8 @@ class Float8CurrentScalingQuantizer(Quantizer): amax: torch.Tensor """FP8 datatype""" dtype: TE_DType + """amax update options""" + use_existing_amax: bool """amax reduction options""" with_amax_reduction: bool amax_reduction_group: Optional[dist_group_type] @@ -240,6 +259,7 @@ def __init__( *, rowwise: bool = True, columnwise: bool = True, + use_existing_amax: bool = False, with_amax_reduction: bool = False, amax_reduction_group: Optional[dist_group_type] = None, force_pow_2_scales: bool = False, @@ -249,6 +269,7 @@ def __init__( self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype + self.use_existing_amax = use_existing_amax self.with_amax_reduction = with_amax_reduction self.amax_reduction_group = amax_reduction_group self.force_pow_2_scales = force_pow_2_scales @@ -283,6 +304,16 @@ def update_quantized( return 
dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) + def make_empty( self, shape: Iterable[int], @@ -290,6 +321,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> Float8Tensor: # Canonicalize tensor attributes @@ -297,25 +329,26 @@ def make_empty( device = torch.device("cuda") # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) + data = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - inner_dim = data.size(-1) + transpose_shape = [shape[-1]] + list(shape[:-1]) data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) - # Construct FP8 tensor return Float8Tensor( shape=shape, dtype=dtype, data=data, - fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device), + fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory), fp8_dtype=self.dtype, requires_grad=requires_grad, data_transpose=data_transpose, @@ -345,7 +378,7 @@ def create_tensor_from_data( torch.float8_e5m2fnuz, ] if internal: - return Float8TensorBase( + return Float8TensorStorage( data=data, fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=data.device), fp8_dtype=self.dtype, @@ -400,7 +433,7 @@ def supports_only_rowwise_all_gather(self) -> bool: return True -class 
Float8Tensor(Float8TensorBase, QuantizedTensor): +class Float8Tensor(Float8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -455,19 +488,6 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: return _FromFloat8Func.apply(self, dtype) return _FromFloat8Func.forward(None, self, dtype) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. - - """ - if self._quantizer is not None: - return self._quantizer - # Now the quantizer for Float8Tensor can be not just Float8Quantizer (delayed scaling) - raise ValueError( - "Float8Tensor's quantizer is None, cannot get a quantizer from Float8Tensor variable" - ) - def quantize_( self, tensor: torch.Tensor, @@ -486,8 +506,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize(), noop_flag=noop_flag) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def detach(self) -> Float8Tensor: # pylint: disable=missing-function-docstring @@ -551,9 +570,36 @@ def remove_caches(self) -> None: self._transpose = None @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs=None): + def make_like( + cls, + tensor: QuantizedTensor, + *, + shape: Optional[Iterable[int]] = None, + dtype: Optional[torch.dtype] = None, + requires_grad: bool = False, + data: Optional[torch.Tensor] = None, + data_transpose: Optional[torch.Tensor] = None, + ) -> QuantizedTensor: + """Create new quantized tensor + + By default, new tensor has the same attributes and underlying + data. 
- # View op + """ + if shape is None and data is not None: + shape = data.shape + new_tensor = super().make_like( + tensor, shape=shape, dtype=dtype, requires_grad=requires_grad + ) + if data is not None: + new_tensor._data = data + if data_transpose is not None: + new_tensor._transpose = data_transpose + new_tensor._transpose_invalid = False + return new_tensor + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): if func == aten.view.default: tensor = args[0] data = tensor._data @@ -572,6 +618,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): or out_transpose_shape[1:] != out_shape[:-1] ): out_transpose = None + else: + view_shape_for_transpose = [out_shape[-1]] + list(out_shape[:-1]) + out_transpose = out_transpose.view(*view_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, @@ -604,11 +653,37 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return [ - Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) - for split_tensor in func_out + t_func_out = [None] * len(func_out) + # Compute corresponding split of the transpose cache if available + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + ndim = data.dim() + # Figure out the original split dim + if "dim" in kwargs: + dim_to_split = kwargs["dim"] + else: + dim_to_split = args[2] if len(args) > 2 else 0 + # Dimension along which transpose needs to be split + t_dim = 0 if dim_to_split == ndim - 1 else dim_to_split + 1 + t_func_out = transpose.__torch_dispatch__( + func, + types, + [transpose, args[1], t_dim], + kwargs, + ) + outs = [ + Float8Tensor.make_like( + tensor, + data=split_tensor, + data_transpose=split_transpose_tensor, + shape=split_tensor.shape, + ) + for split_tensor, split_transpose_tensor in zip(func_out, t_func_out) ] + return outs + if func == aten.new_zeros.default: + # create fresh new tensor with 
zeros. tensor = args[0] data = tensor._data func_out = data.__torch_dispatch__( @@ -617,29 +692,83 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + func_transposed_out = None + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + size = args[1] + t_shape = [size[-1]] + list(size[:-1]) + func_transposed_out = transpose.__torch_dispatch__( + func, + types, + [transpose, t_shape] + list(args[2:]), + kwargs, + ) + # deep copy the scale inverse tensor and quantizer as well. + scale_inv = tensor._scale_inv.detach().clone() + quantizer = tensor._quantizer.copy() + out_tensor = Float8Tensor( + data=func_out, + shape=func_out.shape, + dtype=tensor.dtype, + fp8_dtype=tensor._fp8_dtype, + fp8_scale_inv=scale_inv, + data_transpose=func_transposed_out, + quantizer=quantizer, + ) + return out_tensor + if func == torch.ops.aten.as_strided.default: tensor = args[0] data = tensor._data + # Apply as_strided to the primary uint8 data func_out = data.__torch_dispatch__( func, types, [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) + func_transposed_out = None + if tensor._transpose is not None and not tensor._transpose_invalid: + transpose = tensor._transpose + size = args[1] + stride = args[2] + if "storage_offset" in kwargs: + storage_offset = kwargs["storage_offset"] + else: + storage_offset = args[3] if len(args) > 3 else 0 + # Shape and strided needed for transpose matrix + t_size = [size[-1]] + list(size[:-1]) + t_stride = [stride[-1]] + list(stride[:-1]) + func_transposed_out = transpose.__torch_dispatch__( + func, + types, + [transpose, t_size, t_stride, storage_offset] + list(args[4:]), + kwargs, + ) + return Float8Tensor.make_like( + tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + ) + if func == 
torch.ops.aten.detach.default: return cls.detach(args[0]) if func == torch.ops.aten.clone.default: return cls.clone(args[0]) + if func == torch.ops.aten.copy_.default: dst, src = args[0], args[1] # Just copy FP8 attrs if copying between Float8Tensors if isinstance(src, Float8Tensor) and isinstance(dst, Float8Tensor): - dst._data.copy_(src._data.detach()) - dst._scale_inv.copy_(src._scale_inv.view(dst._scale_inv.size())) - if src._transpose is not None or dst._transpose is not None: - dst._create_transpose() + if dst._data is not None: + dst._data.copy_(src._data.detach(), *args[2:], **kwargs) + if dst._scale_inv is not None: + dst._scale_inv.copy_( + src._scale_inv.view(dst._scale_inv.size()), *args[2:], **kwargs + ) + if dst._transpose is not None and not dst._transpose_invalid: + if not src._transpose_invalid: + dst._transpose.copy_(src._transpose, *args[2:], **kwargs) + else: + dst._create_transpose() return dst elif func in _ops_to_preserve_subclass_in_fsdp2: # Ops in the _ops_to_preserve_subclass_in_fsdp2 are recommened to return the same class instance to work fine with the torch fsdp2 @@ -649,9 +778,105 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) else: pass - return super().__torch_dispatch__(func, types, args, kwargs) + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape) + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor + (For us same as self.stride()) + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this FP8 tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. 
+ + Returns: + shareded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered.(In this case uint8 data tensor) + metadata: Tuple[Any]: Metadata needed for reconstructing the + Float8Tensor after all-gather. + """ + # pylint: disable=unused-argument + # Importing here to avoid circular imports + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + if isinstance(self._quantizer, Float8CurrentScalingQuantizer) and mesh is not None: + # When sharded weight is updated after reduce scattering the gradients in FSDP2, + # we need to do amax reduction across the mesh to make sure all weight shards are + # updated with same scale inverse. Setting the state below in the quantizer will make + # sure that updated Quantized weight tensor have same scale inverse across all shards. + self._quantizer.amax_reduction_group = mesh.get_group() + self._quantizer.with_amax_reduction = True + quantizer = self._quantizer.copy() # quantizer to be used for allgathered weights + fsdp_state = _get_module_fsdp_state(module) + reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + # If weights are resharded after forward pass, then its enough to set the quantizer usages + # based on whether its forward or backward pass for the allgathered weights. + # If not resharded after forward pass, the same weights allgathered in forward + # are used again in backward and so we dont change the quantizer usages which might need + # both rowwise and columnwise usages. + if reshard_after_forward: + training_state = fsdp_state._fsdp_param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + # In case of hopper/L40, only one of data/transpose is needed + # based on forward or backward pass. So setting the quantizer usages appropriately. 
+ quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) + sharded_tensors = (self._data,) + metadata = (self._scale_inv, self._fp8_dtype, quantizer) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[Float8Tensor] = None, + ): + """Functions FSDP2 calls after all-gather of the + weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the Float8Tensor. + param_dtype (torch.dtype): high precision dtype of the Float8Tensor. + out (Optional[torch.Tensor], optional): _description_. Defaults to None. + + Returns: + Tuple[Float8Tensor, Tuple[torch.Tensor, ...]]: Allgathered Float8Tensor and tuple of internal tensors + used by the Float8Tensor that was being computed after allgather. + """ + + (data,) = all_gather_outputs + (fp8_scale_inv, fp8_dtype, quantizer) = metadata + orig_shape = data.size() + # Quantizer has only columnwise usage set for backward pass + # In Blackwell+ architectures, transpose is not needed at all, + # even if columnwise usage is set. and is going to be handled + # internally in the update_usage method. 
+ if out is not None: + out._data = data + else: + fp8_args = { + "shape": orig_shape, + "dtype": param_dtype, + "fp8_scale_inv": fp8_scale_inv, + "fp8_dtype": fp8_dtype, + "quantizer": quantizer, + "requires_grad": False, + "data": data, + } + out = Float8Tensor(**fp8_args) + + out.update_usage( + rowwise_usage=quantizer.rowwise_usage, + columnwise_usage=quantizer.columnwise_usage, + ) + return out, all_gather_outputs + @classmethod def _make_in_reduce_ex( cls, @@ -769,6 +994,9 @@ def forward( out_transpose_shape = out_transpose.size() if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: out_transpose = None + else: + view_shape_for_transpose = [shape[-1]] + list(shape[:-1]) + out_transpose = out_transpose.view(*view_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, @@ -813,6 +1041,9 @@ def forward( out_transpose_shape = out_transpose.size() if out_transpose_shape[0] != out_shape[-1] or out_transpose_shape[1:] != out_shape[:-1]: out_transpose = None + else: + reshape_shape_for_transpose = [shape[-1]] + list(shape[:-1]) + out_transpose = out_transpose.reshape(*reshape_shape_for_transpose) return Float8Tensor( shape=out_shape, dtype=tensor.dtype, diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 16b1568cb..f9ff4b77b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -4,27 +4,30 @@ # # See LICENSE for license information. 
-"""Tensor class with FP8 data""" +"""Tensor class with MXFP8 data""" from __future__ import annotations from collections.abc import Iterable import math -import os -from typing import Optional, Tuple, Union -from torch.utils.cpp_extension import IS_HIP_EXTENSION -import torch -if IS_HIP_EXTENSION: - from ..triton_kernels.cast import te_quantize_triton +from typing import Optional, Tuple, Union, Any +import warnings +import torch +from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe from ..constants import MXFP8_BLOCK_SCALING_SIZE from ..utils import devices_match, round_up_to_nearest_multiple +from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc -from ._internal.mxfp8_tensor_base import MXFP8TensorBase, _FromMXFP8Func -from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + import os + from ..triton_kernels.cast import te_quantize_triton aten = torch.ops.aten @@ -79,6 +82,16 @@ def update_quantized( return dst + def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: + """Quantize tensor implementation""" + if IS_HIP_EXTENSION: + from ..triton_kernels.cast import te_quantize_triton + use_cast_transpose_triton = bool( int(os.environ.get('NVTE_USE_CAST_TRANSPOSE_TRITON', '0')) ) + quantize_func = te_quantize_triton if use_cast_transpose_triton else tex.quantize + return quantize_func(tensor, self) + else: + return tex.quantize(tensor, self) + def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" if inp.ndim < 2: @@ -96,6 +109,7 @@ def make_empty( dtype: torch.dtype = torch.float32, device: 
Optional[torch.device] = None, requires_grad: bool = False, + pin_memory: bool = False, ) -> MXFP8Tensor: # Canonicalize tensor attributes @@ -111,26 +125,33 @@ def make_empty( ) # Allocate FP8 data - data = torch.empty(shape, dtype=torch.uint8, device=device) - # ROCm TE does not implement fuse padding zeros so use zero tensor here - scale_inv = torch.zeros( - round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), - round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), - dtype=torch.uint8, - device=device, - ) + data = None + scale_inv = None + if self.rowwise_usage: + data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory) + # ROCm TE does not implement fuse padding zeros so use zero tensor here + scale_inv = torch.zeros( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, + ) # Allocate FP8 data transpose if needed columnwise_data = None columnwise_scale_inv = None if self.columnwise_usage: - columnwise_data = torch.empty_like(data) + columnwise_data = torch.empty( + shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + ) # ROCm TE does not implement fuse padding zeros so use zero tensor here columnwise_scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, device=device, + pin_memory=pin_memory, ) # Construct FP8 tensor @@ -175,14 +196,14 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: data, scale_inv = torch.ops.tex.mxfp8_quantize(tensor) return self.create_tensor_from_data(data, scale_inv, fake_dtype=torch.float32) - def onnx_dequantize(self, tensor: Union[MXFP8TensorBase, MXFP8Tensor]) -> torch.Tensor: + def onnx_dequantize(self, tensor: Union[MXFP8TensorStorage, MXFP8Tensor]) -> torch.Tensor: return 
torch.ops.tex.mxfp8_dequantize(tensor._rowwise_data, tensor._rowwise_scale_inv) def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return MXFP8BlockScaling -class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): +class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor): """Experimental tensor class with FP8 data The tensor presents as having a standard, higher-precision dtype, @@ -200,14 +221,13 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor): Reciprocal of the scaling factor applied when casting to FP8, i.e. the scaling factor that must be applied when casting from FP8 to higher - precision. Can be inferred from fp8_meta if - provided. + precision. dtype: torch.dtype, default = torch.float32 Nominal tensor datatype. """ - # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorBase with positional args, + # NOTE: We reorder the *args so that we can instantiate a MXFP8TensorStorage with positional args, # which significantly reduces the Pybind11 overhead when calling the constructor from C++. def __new__( cls, @@ -251,17 +271,9 @@ def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: return _FromMXFP8Func.apply(self, dtype) return _FromMXFP8Func.forward(None, self, dtype) - def _get_quantizer(self) -> Quantizer: - """Get builder for quantized tensor - - Quantizer can be used for in-place operations. 
- - """ - if self._quantizer is not None: - return self._quantizer - return MXFP8Quantizer( - fp8_dtype=self._fp8_dtype, - ) + def _build_default_quantizer(self) -> Optional[Quantizer]: + """Build default quantizer for the tensor""" + return MXFP8Quantizer(fp8_dtype=self._fp8_dtype) def quantize_( self, @@ -281,8 +293,7 @@ def quantize_( """ if isinstance(tensor, QuantizedTensor): return self.quantize_(tensor.dequantize()) - self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) - return self + return super().quantize_(tensor, noop_flag=noop_flag) def detach(self) -> MXFP8Tensor: # pylint: disable=missing-function-docstring @@ -317,7 +328,6 @@ def contiguous( memory_format: torch.memory_format = torch.contiguous_format, ) -> MXFP8Tensor: """Returns tensor with data in provided memory format - Returns `self` if data is already in correct memory format. """ @@ -333,7 +343,6 @@ def contiguous( @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): - # View op if func == aten.view.default: tensor = args[0] @@ -357,9 +366,339 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): fp8_dtype=tensor._fp8_dtype, ) + if func == torch.ops.aten.copy_.default: + dst, src = args[0], args[1] + if isinstance(src, MXFP8Tensor) and isinstance(dst, MXFP8Tensor): + # Booleans to check if src has all the usages that dst needs to respect dst quantizer usages. + # If not, default to base class behavior. 
+ rowwise_matches = src._rowwise_data is not None or dst._rowwise_data is None + columnwise_matches = ( + src._columnwise_data is not None or dst._columnwise_data is None + ) + if rowwise_matches and columnwise_matches: + if dst._rowwise_data is not None: + dst._rowwise_data.copy_(src._rowwise_data.detach(), *args[2:], **kwargs) + dst._rowwise_scale_inv.copy_( + src._rowwise_scale_inv.detach(), *args[2:], **kwargs + ) + if dst._columnwise_data is not None: + dst._columnwise_data.copy_( + src._columnwise_data.detach(), *args[2:], **kwargs + ) + dst._columnwise_scale_inv.copy_( + src._columnwise_scale_inv.detach(), *args[2:], **kwargs + ) + return dst + + # FSDP2 related functions. + if func == aten.split.Tensor: + # This is called if entire model is initialized on CUDA device and + # then splitted. Finally the shard needed by the process is used + # and other splitted shards are discarded. + if "dim" in kwargs: + dim_to_split = kwargs["dim"] + else: + dim_to_split = args[2] if len(args) > 2 else 0 + tensor = args[0] + split_size = args[1] + dim0_size = tensor.size(0) + dimlast_size = math.prod(tensor.shape[1:]) + if ( + dim0_size % split_size != 0 + or dim_to_split != 0 + or split_size % MXFP8_BLOCK_SCALING_SIZE != 0 + or dimlast_size % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + # Handle splitting by dequantizing and splitting the hp tensor + return super().__torch_dispatch__(func, types, args, kwargs) + + out_data = [] + for data in [tensor._rowwise_data, tensor._columnwise_data]: + func_out = ( + data.__torch_dispatch__( + func, + types, + [data] + list(args[1:]), + kwargs, + ) + if data is not None + else None + ) + out_data.append(func_out) + + scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv] + split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE] + # Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4 + padding_multiples = [128, 4] + for scale_inv, scale_split_size, 
pad_multiple in zip( + scale_invs, split_sizes_for_scale, padding_multiples + ): + scale_inv_out = ( + scale_inv.__torch_dispatch__( + func, + types, + [scale_inv, scale_split_size] + list(args[2:]), + kwargs, + ) + if scale_inv is not None + else None + ) + # Pad scale_inv_out to be a multiple of pad_multiple + if scale_inv_out is not None: + current_shape = scale_inv_out.shape + pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple + if pad_dim0 > 0: + scale_inv_out = torch.nn.functional.pad(scale_inv_out, (0, 0, 0, pad_dim0)) + + out_data.append(scale_inv_out) + return [ + MXFP8Tensor( + shape=( + splitted_tensor_data[0].size() + if splitted_tensor_data[0] is not None + else splitted_tensor_data[1].size() + ), + dtype=tensor.dtype, + rowwise_data=splitted_tensor_data[0], + rowwise_scale_inv=splitted_tensor_data[2], + columnwise_data=splitted_tensor_data[1], + columnwise_scale_inv=splitted_tensor_data[3], + quantizer=tensor._quantizer, + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) + for splitted_tensor_data in zip(*out_data) + ] + if func == torch.ops.aten.as_strided.default: + # Applied on unsharded param in FSDP2. In our case, this should be a no-op + # This is needed for the case where some MXFP8 shards need padding i.e dimension 0 + # of the unsharded param is not a multiple of the world size. If that is the case, + # we down the dequantization route and weights are allgathered in high precision. + # If weight doesnt need padding, this is just a no-op. + shape = args[1] + strides = args[2] + tensor = args[0] + if ( + len(shape) != 2 + or len(strides) != 2 + or strides[1] != 1 + or shape[0] != tensor.shape[0] + or shape[1] != tensor.shape[1] + ): + return super().__torch_dispatch__(func, types, args, kwargs) + + return MXFP8Tensor.make_like(tensor) + + if func == aten.slice.Tensor: + # FSDP2 needed function. 
+ # We need slicing for the case where some MXFP8 weight shards need padding i.e dimension 0 + # of the unsharded param is not a multiple of the world size. If that is the case, + # we down the dequantization route and weights are allgathered in high precision instead. + # If sharded weight doesnt have padding, this is just a no-op. + dim = args[1] + start = args[2] + length = args[3] + tensor = args[0] + if ( + dim != 0 + or length != tensor.shape[0] + or start != 0 + or length % MXFP8_BLOCK_SCALING_SIZE != 0 + or start % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + return super().__torch_dispatch__(func, types, args, kwargs) + return MXFP8Tensor.make_like(tensor) + + if func == aten.new_zeros.default: + rowwise_data = None + columnwise_data = None + rowwise_scale_inv = None + columnwise_scale_inv = None + tensor = args[0] + shape = args[1] + first_dim = math.prod(shape[:-1]) + last_dim = shape[-1] + if ( + first_dim % MXFP8_BLOCK_SCALING_SIZE != 0 + or last_dim % MXFP8_BLOCK_SCALING_SIZE != 0 + ): + return super().__torch_dispatch__(func, types, args, kwargs) + rowwise_scale_inv_shape = [first_dim, last_dim // MXFP8_BLOCK_SCALING_SIZE] + columnwise_scale_inv_shape = [ + first_dim // MXFP8_BLOCK_SCALING_SIZE, + last_dim, + ] + if tensor._rowwise_data is not None: + rowwise_data = tensor._rowwise_data.__torch_dispatch__( + func, + types, + [tensor._rowwise_data] + list(args[1:]), + kwargs, + ) + rowwise_scale_inv = tensor._rowwise_scale_inv.__torch_dispatch__( + func, + types, + [tensor._rowwise_scale_inv, rowwise_scale_inv_shape] + list(args[2:]), + kwargs, + ) + if tensor._columnwise_data is not None: + columnwise_data = tensor._columnwise_data.__torch_dispatch__( + func, + types, + [tensor._columnwise_data] + list(args[1:]), + kwargs, + ) + columnwise_scale_inv = tensor._columnwise_scale_inv.__torch_dispatch__( + func, + types, + [tensor._columnwise_scale_inv, columnwise_scale_inv_shape] + list(args[2:]), + kwargs, + ) + return MXFP8Tensor( + shape=args[1], + 
dtype=tensor.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + quantizer=tensor._quantizer.copy(), + requires_grad=False, + fp8_dtype=tensor._fp8_dtype, + ) # Default case return super().__torch_dispatch__(func, types, args, kwargs) + def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, mp_policy): + """Functions FSDP2 calls before all-gather of the + weights for both forward and backward passes. + Args: + mesh (torch.distributed.DeviceMesh): DeviceMesh used by FSDP2 + to shard the weights. + orig_size (torch.Size): Original size of the weight tensor.(For us same as self.shape) + contiguous_orig_stride (Tuple[int]): Original stride of the weight tensor + (For us same as self.stride()). + module (FSDPModule): FSDP module. FSDP wrapped module wrapped using fully_shard + that contains this MXFP8 tensor. + mp_policy (MixedPrecisionPolicy): Mixed precision policy used by FSDP2. + + Returns: + sharded_tensors: Tuple[torch.Tensor, ...]: Tuple of tensors + that need to be all-gathered. + metadata: Tuple[Any]: Metadata needed for reconstructing the + MXFP8Tensor after all-gather. 
+ """ + # pylint: disable=unused-argument + from transformer_engine.pytorch.distributed import _get_module_fsdp_state + + fsdp_state = _get_module_fsdp_state(module) + reshard_after_forward = fsdp_state._fsdp_param_group._reshard_after_forward + quantizer = self._quantizer.copy() + # Remove padding from scale inverses before allgather + # Rowwise scale_inv should be divisible by [128,4], columnwise by [4, 128] + rowwise_scale_inv = self._rowwise_scale_inv + columnwise_scale_inv = self._columnwise_scale_inv + shape = self.shape + if rowwise_scale_inv is not None: + # Remove padding from rowwise scale_inv + flattened_in_shape0 = math.prod(shape[:-1]) + if rowwise_scale_inv.size(0) != flattened_in_shape0: + rowwise_scale_inv = rowwise_scale_inv[:flattened_in_shape0] + + if columnwise_scale_inv is not None: + # Remove padding from columnwise scale_inv + flattened_in_shape0 = math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE + if columnwise_scale_inv.size(0) != flattened_in_shape0: + columnwise_scale_inv = columnwise_scale_inv[:flattened_in_shape0] + + sharded_tensors = (self._rowwise_data, rowwise_scale_inv) + # If weights are resharded after forward pass, then its enough to set the quantizer usages + # based on whether its forward or backward pass for the allgathered weights. + # If not resharded after forward pass, the same weights allgathered in forward + # are used again in backward. And hence if we need the columnwise data/scale_inv, + # we need to send them as well for allgather in forward pass itself. 
+ if reshard_after_forward: + training_state = fsdp_state._fsdp_param_group._training_state + is_backward_pass = training_state == TrainingState.PRE_BACKWARD + # Allgather only the necessary tensors based on forward/backward pass + quantizer.set_usage(rowwise=not is_backward_pass, columnwise=is_backward_pass) + sharded_tensors = ( + (self._columnwise_data, columnwise_scale_inv) + if is_backward_pass + else sharded_tensors + ) + else: + if quantizer.columnwise_usage: + # If weights are not resharded after forward, then both + # rowwise and columnwise data/scale_inv need to be allgathered. + sharded_tensors += (self._columnwise_data, columnwise_scale_inv) + metadata = (self._fp8_dtype, quantizer) + return sharded_tensors, metadata + + def fsdp_post_all_gather( + self, + all_gather_outputs: Tuple[torch.Tensor, ...], + metadata: Any, + param_dtype: torch.dtype, + *, + out: Optional[MXFP8Tensor] = None, + ): + """Functions FSDP2 calls after all-gather of the + weights for both forward and backward passes. + Args: + all_gather_outputs (Tuple[torch.Tensor, ...]): sharded_tensors sent out in fsdp_pre_all_gather from each rank + are all-gathered and received here as a tuple. + metadata (Any): metadata sent out in fsdp_pre_all_gather used for reconstructing the MXFP8Tensor. + param_dtype (torch.dtype): high precision dtype of the MXFP8Tensor. + out (Optional[torch.Tensor], optional): _description_. Defaults to None. + Returns: + Tuple[MXFP8Tensor, Tuple[torch.Tensor, ...]]: Allgathered MXFP8Tensor and tuple of internal tensors + used by the MXFP8Tensor that was being computed after allgather. 
+ """ + fp8_dtype, quantizer = metadata + rowwise_data, rowwise_scale_inv = ( + all_gather_outputs[:2] if quantizer.rowwise_usage else (None, None) + ) + columnwise_data, columnwise_scale_inv = ( + all_gather_outputs[-2:] if quantizer.columnwise_usage else (None, None) + ) + + # Add padding to scale_inv tensors to be multiples of [128, 4]for rowwise and [4, 128] for columnwise + if rowwise_scale_inv is not None: + # Pad rowwise_scale_inv to be a multiple of [128, 4] + current_shape = rowwise_scale_inv.shape + pad_dim0 = (128 - current_shape[0] % 128) % 128 + if pad_dim0 > 0: + rowwise_scale_inv = torch.nn.functional.pad(rowwise_scale_inv, (0, 0, 0, pad_dim0)) + + if columnwise_scale_inv is not None: + # Pad columnwise_scale_inv to be a multiple of [4, 128] + current_shape = columnwise_scale_inv.shape + pad_dim0 = (4 - current_shape[0] % 4) % 4 + if pad_dim0 > 0: + columnwise_scale_inv = torch.nn.functional.pad( + columnwise_scale_inv, (0, 0, 0, pad_dim0) + ) + + if out is not None: + out._rowwise_data = rowwise_data + out._rowwise_scale_inv = rowwise_scale_inv + out._columnwise_data = columnwise_data + out._columnwise_scale_inv = columnwise_scale_inv + out._quantizer = quantizer + else: + out = MXFP8Tensor( + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=fp8_dtype, + dtype=param_dtype, + shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + quantizer=quantizer, + ) + + return out, all_gather_outputs + @classmethod def _make_in_reduce_ex( cls, @@ -497,10 +836,14 @@ def forward( shape[i] = d_inferred break if shape[-1] != ctx.shape[-1]: - raise RuntimeError( - "MXFP8Tensor does not support reshaping inner dimension " + warnings.warn( + "MXFP8Tensor does not support reshaping inner dimension. 
" f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + "If you are using this for FSDP2 without compiled_autograd_enabled," + "then ignore this warning. Since this view is not going to be used anywhere. ", + stacklevel=2, ) + return tensor.dequantize().view(*shape) # Construct new tensor if shape is provided new_rowwise_data = None diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py new file mode 100644 index 000000000..31dbcf00a --- /dev/null +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -0,0 +1,926 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tensor class with NVFP4 data""" +from __future__ import annotations +from collections.abc import Iterable +import math +from typing import Dict, Optional, Tuple, Union +import functools + +import torch +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from transformer_engine.common.recipe import NVFP4BlockScaling, Recipe +from ..constants import NVFP4_BLOCK_SCALING_SIZE, dist_group_type +from ..utils import ( + canonicalize_process_group, + devices_match, + round_up_to_nearest_multiple, +) + +from .storage.nvfp4_tensor_storage import NVFP4TensorStorage, _FromNVFP4Func +from ..quantized_tensor import QuantizedTensor, Quantizer +from ._quantization_helpers import _IdentityFunc + +aten = torch.ops.aten + + +def get_no_random_sign_vector() -> torch.Tensor: + """Non-random sign vector for Hadamard transform.""" + return torch.tensor([1], dtype=torch.float32, device="cuda") + + +def get_sign_from_vector(vector: torch.Tensor) -> int: + """Convert sign vector to bitmask. + + Used for random Hadamard transform. 
+ + """ + mask = 0 + for i, v in enumerate(vector): + mask |= (v == -1) << i + return mask.item() + + +def get_wgrad_sign_vector() -> torch.Tensor: + """Hard-coded random signs for Hadamard transform. + + https://xkcd.com/221/ + + """ + return torch.tensor( + [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1], + dtype=torch.float32, + device="cuda", + ) + + +def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor: + """Construct a 16x16 Hadamard matrix.""" + assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported." + hadamard_scale = 1 / math.sqrt(hadamard_dimension) + return ( + torch.tensor( + [ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1, 1, -1], + [1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1], + [1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1, 1, -1, -1, 1], + [1, 1, 1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1], + [1, -1, 1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1], + [1, 1, -1, -1, -1, -1, 1, 1, 1, 1, -1, -1, -1, -1, 1, 1], + [1, -1, -1, 1, -1, 1, 1, -1, 1, -1, -1, 1, -1, 1, 1, -1], + [1, 1, 1, 1, 1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1], + [1, -1, 1, -1, 1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1], + [1, 1, -1, -1, 1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1], + [1, -1, -1, 1, 1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1], + [1, 1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1, 1, 1, 1, 1], + [1, -1, 1, -1, -1, 1, -1, 1, -1, 1, -1, 1, 1, -1, 1, -1], + [1, 1, -1, -1, -1, -1, 1, 1, -1, -1, 1, 1, 1, 1, -1, -1], + [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1], + ], + dtype=torch.float32, + device="cuda", + ) + * hadamard_scale + ) + + +@functools.lru_cache(maxsize=None) +def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor: + """Construct matrix used in random Hadamard transform.""" + hadamard_dimension = 16 + if with_random_sign_mask: + signs = get_wgrad_sign_vector() + else: + signs = get_no_random_sign_vector() + sign_matrix = signs * 
        torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
    rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
    return rht_matrix.to(dtype=torch.bfloat16)


@functools.lru_cache(maxsize=None)
def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
    """Sign mask for random Hadamard transform.

    Returns the packed sign bits derived from the fixed wgrad sign vector
    when ``with_random_sign_mask`` is set, otherwise 0 (no sign flips).
    Cached since the result is constant per flag value.
    """
    if with_random_sign_mask:
        return get_sign_from_vector(get_wgrad_sign_vector())
    return 0


class NVFP4Quantizer(Quantizer):
    """Builder class for NVFP4 tensors with NV block scaling"""

    # FP4 data type produced by this quantizer.
    dtype: TE_DType
    # Random Hadamard Transform options.
    with_rht: bool
    with_post_rht_amax: bool
    # amax reduction options.
    with_amax_reduction: bool
    amax_reduction_group: Optional[dist_group_type]

    # 2D block scaling, only applicable for weights.
    with_2d_quantization: bool

    # Stochastic rounding, only applicable for gradients.
    stochastic_rounding: bool

    # RHT matrix random sign mask and the (bf16) RHT matrix itself.
    rht_matrix_random_sign_mask_t: int
    rht_matrix: torch.Tensor

    def __init__(
        self,
        fp4_dtype: TE_DType = tex.DType.kFloat4E2M1,
        rowwise: bool = True,
        columnwise: bool = True,
        with_amax_reduction: bool = False,
        amax_reduction_group: Optional[dist_group_type] = None,
        with_rht: bool = False,
        with_post_rht_amax: bool = False,
        with_2d_quantization: bool = False,
        stochastic_rounding: bool = False,
        with_random_sign_mask: bool = True,
    ) -> None:
        """Configure NVFP4 quantization options.

        Parameters mirror the class attributes; ``rowwise``/``columnwise``
        select which data layouts the quantizer produces.
        """
        super().__init__(rowwise=rowwise, columnwise=columnwise)
        self.dtype = fp4_dtype
        self.with_rht = with_rht
        self.with_post_rht_amax = with_post_rht_amax
        self.with_amax_reduction = with_amax_reduction
        self.amax_reduction_group = amax_reduction_group
        self.with_2d_quantization = with_2d_quantization
        self.stochastic_rounding = stochastic_rounding
        self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
        self.rht_matrix = get_rht_matrix(with_random_sign_mask)

    def update_quantized(
        self,
        src: torch.Tensor,
        dst: QuantizedTensor,
        *,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> QuantizedTensor:
        """Quantize ``src`` into the existing NVFP4 tensor ``dst`` in place.

        ``noop_flag`` is an optional float32 flag tensor that lets the cast
        kernel skip the update at runtime.
        """
        assert isinstance(dst, NVFP4Tensor), f"Cannot store quantized NVFP4 in {type(dst)} type."

        # Make sure input is in expected format (same device, contiguous)
        if not devices_match(src.device, dst.device):
            src = src.to(device=dst.device)
        if not src.is_contiguous():
            src = src.contiguous()

        # Launch cast kernel
        tex.quantize(src, self, dst, noop_flag)

        return dst

    def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
        """Quantize tensor implementation"""
        return tex.quantize(tensor, self)

    def is_quantizable(self, inp: torch.Tensor) -> bool:
        """Returns whether or not given inp can be quantized"""
        # Both the flattened leading dims and the innermost dim must be
        # multiples of the NVFP4 block size.
        if inp.ndim < 2:
            return False
        if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0:
            return False
        if math.prod(inp.shape[:-1]) % NVFP4_BLOCK_SCALING_SIZE != 0:
            return False
        return True

    def get_scale_shape(self, shape: Iterable[int], columnwise: bool) -> Tuple[int, int]:
        """Calculate the shape of the scaling tensor for NVFP4 1D blockwise quantization.

        This method determines the shape of the scaling tensor needed for blockwise quantization,
        taking into account the input tensor shape and whether columnwise scaling is used.

        Parameters
        ----------
        shape : Iterable[int]
            Shape of the input tensor to be quantized
        columnwise : bool
            Whether to use columnwise scaling (True) or rowwise scaling (False)

        Returns
        -------
        Tuple[int, int]
            Shape of the scaling tensor as (outer_dim, inner_dim)
            For NVFP4 1D blockwise quantization, blocksize is 16
            - If columnwise: (round_to_multiple(K, 128), round_to_multiple(roundup(M / 16), 4))
            - If rowwise: (round_to_multiple(M, 128), round_to_multiple(roundup(K / 16), 4))
            Swizzle kernel will be performed before GEMM to suit the need of CuBLAS.
            CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
        """
        M, K = 1, 1
        # Collapse leading dims: treat the tensor as a 2D (M, K) matrix.
        M = math.prod(shape[:-1])
        K = shape[-1]

        if columnwise:
            outer = round_up_to_nearest_multiple(K, 128)
            inner = round_up_to_nearest_multiple(math.ceil(M / NVFP4_BLOCK_SCALING_SIZE), 4)
            return (outer, inner)
        # rowwise
        outer = round_up_to_nearest_multiple(M, 128)
        inner = round_up_to_nearest_multiple(math.ceil(K / NVFP4_BLOCK_SCALING_SIZE), 4)
        return (outer, inner)

    @staticmethod
    def get_columnwise_shape(shape: Iterable[int]) -> Tuple[int, ...]:
        """Calculate the shape of a tensor after columnwise quantization.

        For NVFP4 columnwise quantization, it's performing 16x1 quantization block scaling.

        Parameters
        ----------
        shape : Iterable[int]
            Original shape of the tensor

        Returns
        -------
        Tuple[int, ...]
            New shape with dimensions rearranged for columnwise layout.
            For a shape (d1, d2, ..., dn), returns (dn, d1, d2, ..., dn-1).
            Returns empty tuple for empty input shape.
        """
        if len(shape) == 0:
            return tuple()
        # and then after AG, a reorganize kernel will be called to restore the shape
        colwise_shape = [shape[-1]]
        for i in range(len(shape) - 1):
            colwise_shape.append(shape[i])
        return tuple(colwise_shape)

    @staticmethod
    def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]:
        """Convert shape for FP4 data by dividing the last dimension by 2"""
        # Two FP4 values are packed per uint8 byte, so the byte buffer's
        # innermost dim is half the logical innermost dim.
        shape = list(shape)
        shape[-1] = shape[-1] // 2
        return tuple(shape)

    def make_empty(
        self,
        shape: Iterable[int],
        *,
        dtype: torch.dtype = torch.float32,
        device: Optional[torch.device] = None,
        pin_memory: bool = False,
        requires_grad: bool = False,
    ) -> NVFP4Tensor:
        """Allocate an uninitialized NVFP4Tensor with the requested usages.

        Allocates rowwise and/or columnwise byte buffers, their uint8 block
        scale-inverse tensors, and one FP32 amax scalar per layout,
        according to this quantizer's usage flags.
        """
        # Canonicalize tensor attributes
        if device is None:
            device = torch.device("cuda")

        assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, (
            f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
            f" {NVFP4_BLOCK_SCALING_SIZE}"
        )

        flat_first_dim = math.prod(shape[:-1])
        assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, (
            f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
            f" {NVFP4_BLOCK_SCALING_SIZE}"
        )

        # Allocate FP4 data (packed two values per byte)
        data = None
        scale_inv = None
        amax_rowwise = None
        if self.rowwise_usage:
            data = torch.empty(
                self.convert_shape_for_fp4(shape),
                dtype=torch.uint8,
                device=device,
                pin_memory=pin_memory,
            )
            scale_shape = self.get_scale_shape(shape, columnwise=False)
            scale_inv = torch.empty(
                scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
            )
            # Per-tensor rowwise amax in FP32.
            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)

        # Allocate columnwise (transposed) FP4 data if needed
        columnwise_data = None
        columnwise_scale_inv = None
        amax_columnwise = None
        if self.columnwise_usage:
            # enforce 2D shape to avoid [S, B, H] shape and B and be 1
            # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
            shape_2d = tuple([flat_first_dim, shape[-1]])
            columnwise_data = torch.empty(
                self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
                dtype=torch.uint8,
                device=device,
                pin_memory=pin_memory,
            )
            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
            columnwise_scale_inv = torch.empty(
                columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
            )
            amax_columnwise = torch.zeros(
                1, dtype=torch.float32, device=device, pin_memory=pin_memory
            )

        # Construct NVFP4 tensor
        return NVFP4Tensor(
            shape=shape,
            dtype=dtype,
            rowwise_data=data,
            rowwise_scale_inv=scale_inv,
            columnwise_data=columnwise_data,
            columnwise_scale_inv=columnwise_scale_inv,
            amax_rowwise=amax_rowwise,
            amax_columnwise=amax_columnwise,
            fp4_dtype=self.dtype,
            quantizer=self,
            requires_grad=requires_grad,
        )

    def calibrate(self, tensor: torch.Tensor) -> None:
        pass  # Calibration is no-op

    def _canonicalized_amax_reduction_group(self) -> dist_group_type:
        """Get process group for amax reduction"""
        return canonicalize_process_group(self.amax_reduction_group)

    def _get_compatible_recipe(self) -> Union[type[Recipe], None]:
        """Recipe class this quantizer implements."""
        return NVFP4BlockScaling


class NVFP4Tensor(NVFP4TensorStorage, QuantizedTensor):
    """Quantized tensor class with FP4 data

    The tensor presents as having a standard, higher-precision dtype,
    but the data itself is (scaled) FP4. For most tensor operations,
    the data will be cast to the nominal dtype before performing the
    operation.

    Parameters
    ----------
    rowwise_data: torch.Tensor
        Raw FP4 data in a uint8 tensor (rowwise layout).
    rowwise_scale_inv: torch.Tensor
        Reciprocal of the scaling factor applied when
        casting to FP4, i.e. the scaling factor that must
        be applied when casting from FP4 to higher
        precision (rowwise).
    columnwise_data: torch.Tensor, optional
        Raw FP4 data in a uint8 tensor (columnwise layout).
    columnwise_scale_inv: torch.Tensor, optional
        Reciprocal of the scaling factor for columnwise FP4 data.
    amax_rowwise: torch.Tensor, optional
        Rowwise amax tracking tensor.
    amax_columnwise: torch.Tensor, optional
        Columnwise amax tracking tensor.
    fp4_dtype: TE_DType
        The FP4 data type used for quantization.
    quantizer: Quantizer
        The quantizer instance used for this tensor.
    dtype: torch.dtype, default = torch.float32
        Nominal tensor datatype, used in dequantize.
    """

    # NOTE: We reorder the *args so that we can instantiate a NVFP4TensorStorage with positional args,
    # which significantly reduces the Pybind11 overhead when calling the constructor from C++.
+ def __new__( + cls, + *args, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: Optional[torch.Tensor], + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: Optional[torch.Tensor], + amax_rowwise: Optional[torch.Tensor], + amax_columnwise: Optional[torch.Tensor], + fp4_dtype: TE_DType, + quantizer: Quantizer, + **kwargs, + ): + instance = super().__new__( + cls, + rowwise_data, + rowwise_scale_inv, + columnwise_data, + columnwise_scale_inv, + amax_rowwise, + amax_columnwise, + fp4_dtype, + quantizer, + *args, + **kwargs, + ) + return instance + + def __repr__(self, *, tensor_contents=None): + return f"NVFP4Tensor, data={self.dequantize(dtype=self.dtype)})" + + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + Construct plain PyTorch tensor from NVFP4Tensor + + By default the resulting tensor's dtype is the + NVFP4Tensor's nominal dtype. + """ + # Convert PyTorch dtype to TE dtype + if dtype is None: + dtype = self.dtype + + if torch.is_grad_enabled(): + return _FromNVFP4Func.apply(self, dtype) + return _FromNVFP4Func.forward(None, self, dtype) + + def _get_quantizer(self) -> Quantizer: + """Get builder for quantized tensor + + Quantizer can be used for in-place operations. 
+ + """ + if self._quantizer is not None: + return self._quantizer + return NVFP4Quantizer() + + def quantize_( + self, + tensor: torch.Tensor, + *, + noop_flag: Optional[torch.Tensor] = None, + ) -> NVFP4Tensor: + """Update FP8 data + + Parameters + ---------- + tensor: torch.Tensor + Tensor to copy from + noop_flag: torch.Tensor, optional + float32 flag indicating whether to avoid performing update + + """ + if isinstance(tensor, QuantizedTensor): + return self.quantize_(tensor.dequantize()) + self._get_quantizer().update_quantized(tensor, self, noop_flag=noop_flag) + return self + + def detach(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + # TODO(ksivamani): Fix the detach bug + return NVFP4Tensor.make_like(self) + + def clone(self) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + assert self._rowwise_data is not None + rowwise_data = self._rowwise_data.detach().clone() + columnwise_data = None + if self._columnwise_data is not None: + columnwise_data = self._columnwise_data.detach().clone() + return _IdentityFunc.apply( + self, + { + "rowwise_data": rowwise_data, + "columnwise_data": columnwise_data, + }, + ) + + def view(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ViewFunc.apply(self, shape) + + def reshape(self, *shape: Tuple[int]) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + return _ReshapeFunc.apply(self, shape) + + def contiguous( + self, + memory_format: torch.memory_format = torch.contiguous_format, + ) -> NVFP4Tensor: + """Returns tensor with data in provided memory format + + Returns `self` if data is already in correct memory format. 
+ + """ + if self._rowwise_data is not None and self._rowwise_data.is_contiguous( + memory_format=memory_format + ): + return self + if self._columnwise_data is not None and self._columnwise_data.is_contiguous( + memory_format=memory_format + ): + return self + raise ValueError("NVFP4Tensor does not support different memory formats!") + + def get_usages(self) -> Dict[str, bool]: + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs=None): + + # View op + if func == aten.view.default: + if len(args) != 2: + raise RuntimeError("Unexpected args for view op (expected 2 args, got {len(args)})") + tensor = args[0] + shape = args[1] + if shape == list(tensor.size()): + return tensor.detach() + return tensor.view(shape) + + # NVFP4 dequantize not supported. Add manual support for needed funcs. + if func in (aten.empty_like.default, aten.zero_.default): + tensor = args[0] + data_init_func = torch.zeros_like if func == aten.zero_.default else torch.empty_like + scale_inv_init_func = ( + torch.ones_like if func == aten.zero_.default else torch.empty_like + ) + + if tensor._rowwise_data is not None: + rowwise_data = data_init_func(tensor._rowwise_data, *args[1:], **kwargs) + rowwise_scale_inv = scale_inv_init_func( + tensor._rowwise_scale_inv, *args[1:], **kwargs + ) + amax_rowwise = torch.zeros_like(tensor._amax_rowwise, *args[1:], **kwargs) + else: + rowwise_data, rowwise_scale_inv, amax_rowwise = None, None, None + + if tensor._columnwise_data is not None: + columnwise_data = data_init_func(tensor._columnwise_data, *args[1:], **kwargs) + columnwise_scale_inv = scale_inv_init_func( + tensor._columnwise_scale_inv, *args[1:], **kwargs + ) + amax_columnwise = torch.zeros_like(tensor._amax_columnwise, *args[1:], **kwargs) + else: + columnwise_data, columnwise_scale_inv, amax_columnwise = ( + None, + None, + None, + ) + + return NVFP4Tensor( + 
shape=tensor.shape, + dtype=tensor.dtype, + fp4_dtype=tensor._fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=tensor._quantizer, + requires_grad=tensor.requires_grad, + ) + + # Default case + return super().__torch_dispatch__(func, types, args, kwargs) + + @classmethod + def _make_in_reduce_ex( + cls, + shape: torch.Size, + rowwise_data: torch.Tensor, + rowwise_scale_inv: torch.Tensor, + columnwise_data: torch.Tensor, + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + dtype: torch.dtype, + quantizer: Quantizer, + ) -> NVFP4Tensor: + """Build NVFP4Tensor, for use in __reduce__ + + __reduce_ex__ assumes object constructor has positional + arguments. + + """ + return NVFP4Tensor( + shape=shape, + dtype=dtype, + fp4_dtype=fp4_dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + amax_rowwise=amax_rowwise, + amax_columnwise=amax_columnwise, + quantizer=quantizer, + requires_grad=False, + ) + + def __reduce_ex__(self, protocol: int) -> tuple: + """Custom pickling""" + return ( + NVFP4Tensor._make_in_reduce_ex, + ( + self.shape, + self._rowwise_data, + self._rowwise_scale_inv, + self._columnwise_data, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + self._fp4_dtype, + self.dtype, + self._quantizer, + ), + ) + + def _get_data(self) -> NVFP4Tensor: + """Get tensor data property""" + return super().data + + @torch.no_grad() + def _set_data(self, tensor: torch.Tensor) -> None: + """Set tensor data property + + Just takes FP8 data if setting from a NVFP4Tensor. Otherwise + casts to FP8. 
+ + """ + + # Tensor device + new_device = tensor.device if tensor.is_cuda else self.device + if not devices_match(new_device, tensor.device): + tensor = tensor.to(device=new_device) + + # Just copy FP8 data if other tensor is NVFP4Tensor + if isinstance(tensor, NVFP4Tensor): + if ( # pylint: disable=too-many-boolean-expressions + self.size() != tensor.size() + or self.stride() != tensor.stride() + or self.storage_offset() != tensor.storage_offset() + or self.dtype != tensor.dtype + or self.layout != tensor.layout + or not devices_match(self.device, new_device) + ): + dummy_tensor = torch.Tensor._make_wrapper_subclass( + NVFP4Tensor, + tensor.size(), + strides=tensor.stride(), + storage_offset=tensor.storage_offset(), + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + device=new_device, + ) + # pylint: disable=unnecessary-dunder-call + super(NVFP4Tensor, type(self)).data.__set__(self, dummy_tensor) + self._rowwise_data = tensor._rowwise_data + self._columnwise_data = tensor._columnwise_data + self._quantizer = tensor._quantizer + self._rowwise_scale_inv = tensor._rowwise_scale_inv + self._columnwise_scale_inv = tensor._columnwise_scale_inv + self._amax_rowwise = tensor._amax_rowwise + self._amax_columnwise = tensor._amax_columnwise + return + + # Quantize to FP8 + assert self._quantizer is not None, "Can't quantize without a quantizer" + self._quantizer.update_quantized(tensor, self) + if self.requires_grad != tensor.requires_grad: + self.requires_grad_(requires_grad=tensor.requires_grad) + + # Cast to FP8 when setting NVFP4Tensor.data + data = property(_get_data, _set_data) + + +class _ViewFunc(torch.autograd.Function): + """View function + + View the NVFP4Tensor using the provided shape. 
+ + """ + + @staticmethod + def forward( + ctx, + tensor: NVFP4Tensor, + shape: Optional[list[int]] = None, + ) -> NVFP4Tensor: + # pylint: disable=missing-function-docstring + + # Return input tensor if shape is not provided + cur_shape = tensor.shape + if ctx is not None: + ctx.shape = cur_shape + if shape is None: + return tensor + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if tensor._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = tensor._rowwise_data.view(byte_shape) + if tensor._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." 
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = tensor._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4Tensor( + shape, + tensor.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=tensor._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=tensor._columnwise_scale_inv, + amax_rowwise=tensor._amax_rowwise, + amax_columnwise=tensor._amax_columnwise, + quantizer=tensor._quantizer, + fp4_dtype=tensor._fp4_dtype, + requires_grad=tensor.requires_grad, + ) + + @staticmethod + def backward( + ctx, + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + + if isinstance(grad, NVFP4Tensor): + new_rowwise_data = None + new_columnwise_data = None + if grad._rowwise_data is not None: + if ctx.shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." + ) + byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2] + new_rowwise_data = grad._rowwise_data.view(byte_shape) + if grad._columnwise_data is not None: + columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={ctx.shape} as byte array." 
                    )
                byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
                new_columnwise_data = grad._columnwise_data.view(byte_shape)
            # Rewrap the re-viewed buffers as an NVFP4Tensor gradient.
            dgrad = NVFP4Tensor(
                ctx.shape,
                grad.dtype,
                rowwise_data=new_rowwise_data,
                rowwise_scale_inv=grad._rowwise_scale_inv,
                columnwise_data=new_columnwise_data,
                columnwise_scale_inv=grad._columnwise_scale_inv,
                amax_rowwise=grad._amax_rowwise,
                amax_columnwise=grad._amax_columnwise,
                quantizer=grad._quantizer,
                fp4_dtype=grad._fp4_dtype,
                requires_grad=grad.requires_grad,
            )
            return dgrad, None
        # Plain tensors can be viewed back directly.
        return grad.view(ctx.shape), None


class _ReshapeFunc(torch.autograd.Function):
    """Reshape function

    Reshape the NVFP4Tensor using the provided shape.

    """

    @staticmethod
    def forward(
        ctx,
        tensor: NVFP4Tensor,
        shape: Optional[list[int]] = None,
    ) -> NVFP4Tensor:
        # pylint: disable=missing-function-docstring

        # Return input tensor if shape is not provided
        cur_shape = tensor.shape
        if ctx is not None:
            ctx.shape = cur_shape
        if shape is None:
            return tensor

        # Canonicalize shape: accept a bare int, a single iterable
        # argument, or varargs; then resolve one -1 entry from the total
        # element count.
        if not isinstance(shape, Iterable):
            shape = [shape]
        elif len(shape) == 1 and isinstance(shape[0], Iterable):
            shape = shape[0]
        if -1 in shape:
            shape = list(shape)
            # prod(shape) is negative here (it contains -1), so the double
            # negation yields the inferred positive dimension.
            d_inferred = -math.prod(cur_shape) // math.prod(shape)
            for i, d in enumerate(shape):
                if d == -1:
                    shape[i] = d_inferred
                    break
        if shape[-1] != cur_shape[-1]:
            # Block scaling runs along the innermost dim, so it must be
            # preserved by any reshape.
            raise RuntimeError(
                "NVFP4Tensor does not support reshaping inner dimension "
                f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
            )

        # Reshape the packed byte buffers (two FP4 values per uint8 byte,
        # hence the // 2 on the innermost dim).
        new_rowwise_data = None
        new_columnwise_data = None
        if tensor._rowwise_data is not None:
            if shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent row-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = list(shape[:-1]) + [shape[-1] // 2]
            new_rowwise_data = tensor._rowwise_data.reshape(byte_shape)
        if tensor._columnwise_data is not None:
            columnwise_shape = (shape[-1], math.prod(shape[:-1]))
            if columnwise_shape[-1] % 2 != 0:
                raise ValueError(
                    "Cannot represent column-wise data for NVFP4 tensor "
                    f"with shape={shape} as byte array."
                )
            byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2)
            new_columnwise_data = tensor._columnwise_data.reshape(byte_shape)

        # Construct tensor sharing scale-invs and amaxes with the input
        return NVFP4Tensor(
            shape,
            tensor.dtype,
            rowwise_data=new_rowwise_data,
            rowwise_scale_inv=tensor._rowwise_scale_inv,
            columnwise_data=new_columnwise_data,
            columnwise_scale_inv=tensor._columnwise_scale_inv,
            amax_rowwise=tensor._amax_rowwise,
            amax_columnwise=tensor._amax_columnwise,
            quantizer=tensor._quantizer,
            fp4_dtype=tensor._fp4_dtype,
            requires_grad=tensor.requires_grad,
        )

    @staticmethod
    def backward(
        ctx,
        grad: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
        # pylint: disable=missing-function-docstring

        # Mirror forward: rebuild the NVFP4 wrapper around reshaped byte
        # buffers when the gradient is itself quantized.
        if isinstance(grad, NVFP4Tensor):
            new_rowwise_data = None
            new_columnwise_data = None
            if grad._rowwise_data is not None:
                if ctx.shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent row-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
                    )
                byte_shape = list(ctx.shape[:-1]) + [ctx.shape[-1] // 2]
                new_rowwise_data = grad._rowwise_data.reshape(byte_shape)
            if grad._columnwise_data is not None:
                columnwise_shape = (ctx.shape[-1], math.prod(ctx.shape[:-1]))
                if columnwise_shape[-1] % 2 != 0:
                    raise ValueError(
                        "Cannot represent column-wise data for NVFP4 tensor "
                        f"with shape={ctx.shape} as byte array."
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = grad._columnwise_data.reshape(byte_shape) + dgrad = NVFP4Tensor( + ctx.shape, + grad.dtype, + rowwise_data=new_rowwise_data, + rowwise_scale_inv=grad._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=grad._columnwise_scale_inv, + amax_rowwise=grad._amax_rowwise, + amax_columnwise=grad._amax_columnwise, + quantizer=grad._quantizer, + fp4_dtype=grad._fp4_dtype, + requires_grad=grad.requires_grad, + ) + return dgrad, None + return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py new file mode 100644 index 000000000..9cb228f3a --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Storage for quantized tensors.""" + +from .float8_tensor_storage import Float8TensorStorage # noqa: F401 +from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 +from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py similarity index 97% rename from transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index da0220eb7..38d117b2a 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -13,16 +13,14 @@ from transformer_engine_torch import DType as TE_DType from transformer_engine_torch import 
Float8BlockScaleTensorFormat -from ..quantized_tensor import QuantizedTensorBase +from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ...constants import TE_DType_To_Torch -from ..quantized_tensor import Quantizer - from ...utils import _empty_tensor -class Float8BlockwiseQTensorBase(QuantizedTensorBase): +class Float8BlockwiseQTensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8BlockwiseQTensor. Float8BlockwiseQTensor inherits from the PyTorch tensor class and this @@ -53,7 +51,7 @@ def __new__( *args, **kwargs, ): - if cls is Float8BlockwiseQTensorBase: + if cls is Float8BlockwiseQTensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -98,7 +96,7 @@ def _is_gemm_ready_format(self) -> bool: def prepare_for_saving( self, - ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]: + ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorStorage]: """ Prepare the tensor base for saving for backward """ @@ -366,7 +364,7 @@ def __repr__(self): data = self.dequantize() descriptor = "columnwise" return ( - "Float8BlockwiseQTensorBase(" + "Float8BlockwiseQTensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"{descriptor}_scaled_data={data}" ) @@ -422,3 +420,10 @@ def update_usage( return return + + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py similarity index 92% rename from transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index 6d4822344..8d12c3070 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ 
b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -12,12 +12,10 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from ..quantized_tensor import QuantizedTensorBase +from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ...constants import TE_DType as torch_to_transformer_engine_dtype -from ..quantized_tensor import Quantizer - from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor @@ -27,7 +25,7 @@ class _FromFloat8Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: Float8TensorBase, + tensor: Float8TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -52,7 +50,7 @@ def backward( return grad, None -class Float8TensorBase(QuantizedTensorBase): +class Float8TensorStorage(QuantizedTensorStorage): """Mixin class that holds data attributes of Float8Tensor. Float8Tensor inherits from the PyTorch tensor class and this mixin @@ -81,7 +79,7 @@ def __new__( quantizer: Optional[Quantizer] = None, **kwargs, ): - if cls is Float8TensorBase: + if cls is Float8TensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -116,7 +114,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [self._data, self._transpose, self._scale_inv] self._data = None @@ -163,7 +161,7 @@ def view(self, shape: torch.Size): if out_transpose_shape[0] != shape[-1] or out_transpose_shape[1:] != shape[:-1]: out_transpose = None - return Float8TensorBase( + return Float8TensorStorage( data=out_data, fp8_scale_inv=self._scale_inv, fp8_dtype=self._fp8_dtype, @@ -173,7 +171,7 
@@ def view(self, shape: torch.Size): def __repr__(self): return ( - "Float8TensorBase(" + "Float8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"scale_inv={self._scale_inv.item()}, " f"data={self.dequantize()}" @@ -227,3 +225,12 @@ def update_usage( if not needs_data_transpose: self._transpose = None self._transpose_invalid = True + + def get_usages(self) -> Dict[str, bool]: + """Get the usage of the tensor""" + usages = {"rowwise": self._data is not None} + if is_non_tn_fp8_gemm_supported(): + usages["columnwise"] = self._data is not None + else: + usages["columnwise"] = self._transpose is not None and not self._transpose_invalid + return usages diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py similarity index 94% rename from transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py rename to transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 11055c4cc..f05dca705 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -17,12 +17,10 @@ from transformer_engine_torch import DType as TE_DType from torch.utils.cpp_extension import IS_HIP_EXTENSION -from ..quantized_tensor import QuantizedTensorBase +from ...quantized_tensor import QuantizedTensorStorage, Quantizer from ...constants import TE_DType as torch_to_transformer_engine_dtype -from ..quantized_tensor import Quantizer - from ...utils import _empty_tensor @@ -32,7 +30,7 @@ class _FromMXFP8Func(torch.autograd.Function): @staticmethod def forward( _ctx: Optional[torch.autograd.function.FunctionCtx], # unused - tensor: MXFP8TensorBase, + tensor: MXFP8TensorStorage, dtype: torch.dtype, ) -> torch.Tensor: # pylint: disable=missing-function-docstring @@ -57,7 +55,7 @@ def backward( return grad, None -class MXFP8TensorBase(QuantizedTensorBase): +class MXFP8TensorStorage(QuantizedTensorStorage): 
"""Mixin class that holds data attributes of MXFP8Tensor. MXFP8Tensor inherits from the PyTorch tensor class and this mixin @@ -85,7 +83,7 @@ def __new__( *args, **kwargs, ): - if cls is MXFP8TensorBase: + if cls is MXFP8TensorStorage: instance = object.__new__(cls) else: instance = super().__new__(cls, *args, **kwargs) @@ -120,7 +118,7 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, } - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]: + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorStorage]: """Prepare the tensor base for saving for backward""" tensors = [ self._rowwise_data, @@ -200,7 +198,7 @@ def view(self, shape: torch.Size): if cur_columnwise_data is not None: new_columnwise_data = cur_columnwise_data.view(*shape) - return MXFP8TensorBase( + return MXFP8TensorStorage( rowwise_data=new_rowwise_data, rowwise_scale_inv=self._rowwise_scale_inv, columnwise_data=new_columnwise_data, @@ -213,7 +211,7 @@ def __repr__(self): data_rowwise = self.dequantize() return ( - "MXFP8TensorBase(" + "MXFP8TensorStorage(" f"fp8_dtype={self._fp8_dtype}, " f"rowwise_scaled_data={data_rowwise}" f"rowwise_scale_inv={self._rowwise_scale_inv}, " @@ -264,3 +262,10 @@ def update_usage( else: self._columnwise_data = None self._columnwise_scale_inv = None + + def get_usages(self) -> Tuple[bool, bool]: + """Get the usage of the tensor""" + return { + "rowwise": self._rowwise_data is not None, + "columnwise": self._columnwise_data is not None, + } diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py new file mode 100644 index 000000000..04ab092ee --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -0,0 +1,320 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Mixin class holding data specific for NVFP4Tensor""" + +from __future__ import annotations +from collections.abc import Iterable +import functools +import math +from typing import Any, Dict, Optional, Tuple, Union +import warnings + +import torch + +import transformer_engine_torch as tex +from transformer_engine_torch import DType as TE_DType + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ...constants import TE_DType as torch_to_transformer_engine_dtype +from ...utils import _empty_tensor + + +@functools.lru_cache(maxsize=None) +def _fp4_e2m1_vals(device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Values representable in FP4 E2M1 format""" + return torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], + device=device, + dtype=dtype, + ) + + +class _FromNVFP4Func(torch.autograd.Function): + """Cast from NVFP4 to other dtype""" + + @staticmethod + def forward( + _ctx: Optional[torch.autograd.function.FunctionCtx], # unused + tensor: NVFP4TensorStorage, + dtype: torch.dtype, + ) -> torch.Tensor: + # pylint: disable=missing-function-docstring + + # Dequantize row-wise data + if tensor._rowwise_data is not None: + return tex.dequantize(tensor, torch_to_transformer_engine_dtype[dtype]) + + if tensor._columnwise_data is not None: + raise NotImplementedError("Dequantizing column-wise NVFP4 data is not implemented yet!") + raise ValueError("Attempted to dequantize NVFP4 tensor with no data") + + @staticmethod + def backward( + _ctx: torch.autograd.function.FunctionCtx, # unused + grad: torch.Tensor, + ) -> Tuple[Optional[torch.Tensor], ...]: + # pylint: disable=missing-function-docstring + # Assume that we want gradients in full precision + return grad, None + + +class NVFP4TensorStorage(QuantizedTensorStorage): + """Mixin class that holds data attributes of NVFP4Tensor. + + NVFP4Tensor inherits from the PyTorch tensor class and this mixin + class. 
If this class is instantiated directly, it has the same + data, lower CPU overhead, and less functionality. It should only + be instantiated directly for performance-critical internal usage. + + """ + + _rowwise_data: Optional[torch.Tensor] + _columnwise_data: Optional[torch.Tensor] + _quantizer: Optional[Quantizer] + _rowwise_scale_inv: torch.Tensor + _columnwise_scale_inv: torch.Tensor + _fp4_dtype: TE_DType + _amax_rowwise: torch.Tensor + _amax_columnwise: torch.Tensor + + def __new__( + cls, + rowwise_data: Optional[torch.Tensor], + rowwise_scale_inv: torch.Tensor, + columnwise_data: Optional[torch.Tensor], + columnwise_scale_inv: torch.Tensor, + amax_rowwise: torch.Tensor, + amax_columnwise: torch.Tensor, + fp4_dtype: TE_DType, + quantizer: Optional[Quantizer], + *args, + **kwargs, + ): + + instance = super().__new__(cls, *args, **kwargs) + + instance._rowwise_data = rowwise_data + instance._columnwise_data = columnwise_data + instance._fp4_dtype = fp4_dtype + instance._quantizer = quantizer.copy() if quantizer is not None else None + instance._rowwise_scale_inv = rowwise_scale_inv + instance._columnwise_scale_inv = columnwise_scale_inv + instance._amax_rowwise = amax_rowwise + instance._amax_columnwise = amax_columnwise + + return instance + + def clear(self): + """Deallocate this tensor's memory. 
Typically not needed and must be used carefully.""" + for t in ( + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ): + if t is not None: + t.data = _empty_tensor() + + def get_metadata(self) -> Dict[str, Any]: + """Get this tensor's metadata.""" + return { + "rowwise_data": self._rowwise_data, + "rowwise_scale_inv": self._rowwise_scale_inv, + "columnwise_data": self._columnwise_data, + "columnwise_scale_inv": self._columnwise_scale_inv, + "amax_rowwise": self._amax_rowwise, + "amax_columnwise": self._amax_columnwise, + "fp4_dtype": self._fp4_dtype, + "quantizer": self._quantizer, + } + + def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], NVFP4TensorStorage]: + """Prepare the tensor base for saving for backward""" + tensors = [ + self._rowwise_data, + self._columnwise_data, + self._rowwise_scale_inv, + self._columnwise_scale_inv, + self._amax_rowwise, + self._amax_columnwise, + ] + self._rowwise_data = None + self._columnwise_data = None + self._rowwise_scale_inv = None + self._columnwise_scale_inv = None + self._amax_rowwise = None + self._amax_columnwise = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self._rowwise_data = tensors[0] + self._columnwise_data = tensors[1] + self._rowwise_scale_inv = tensors[2] + self._columnwise_scale_inv = tensors[3] + self._amax_rowwise = tensors[4] + self._amax_columnwise = tensors[5] + return tensors[6:] + + def get_data_tensors(self): + """Get this Tensor's data.""" + return self._rowwise_data, self._columnwise_data + + def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Dequantize to a higher precision.""" + return _FromNVFP4Func.forward(None, self, dtype) + + def size(self, dim: Optional[int] = None) -> Union[torch.Size, 
int]: + # pylint: disable=missing-function-docstring + + # Infer tensor shape + shape = None + if self._rowwise_data is not None: + byte_shape = list(self._rowwise_data.size()) + shape = byte_shape[:-1] + [byte_shape[-1] * 2] + elif self._columnwise_data is not None: + warnings.warn("Attempting to get shape of NVFP4 tensor with only column-wise data.") + byte_shape = list(self._columnwise_data.size()) + shape = byte_shape[1:-1] + [byte_shape[-1] * 2, byte_shape[0]] + if shape is None: + raise RuntimeError("Attempted to get shape of NVFP4 tensor with no data") + + # Return shape or dim + if dim is None: + return torch.Size(shape) + return shape[dim] + + def view(self, shape: torch.Size): + # pylint: disable=missing-function-docstring + + # Return input tensor if view not needed + cur_shape = self.size() + if shape is None or shape == cur_shape: + return self + + # Canonicalize shape + if not isinstance(shape, Iterable): + shape = [shape] + elif len(shape) == 1 and isinstance(shape[0], Iterable): + shape = shape[0] + if -1 in shape: + shape = list(shape) + d_inferred = -math.prod(cur_shape) // math.prod(shape) + for i, d in enumerate(shape): + if d == -1: + shape[i] = d_inferred + break + if shape[-1] != cur_shape[-1]: + raise RuntimeError( + "NVFP4Tensor does not support reshaping inner dimension " + f"(attempted to reshape dims={tuple(cur_shape)} to {tuple(shape)})" + ) + + # Reshape data + new_rowwise_data = None + new_columnwise_data = None + if self._rowwise_data is not None: + if shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent row-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." + ) + byte_shape = list(shape[:-1]) + [shape[-1] // 2] + new_rowwise_data = self._rowwise_data.view(byte_shape) + if self._columnwise_data is not None: + columnwise_shape = (shape[-1], math.prod(shape[:-1])) + if columnwise_shape[-1] % 2 != 0: + raise ValueError( + "Cannot represent column-wise data for NVFP4 tensor " + f"with shape={shape} as byte array." 
+ ) + byte_shape = (columnwise_shape[0], columnwise_shape[1] // 2) + new_columnwise_data = self._columnwise_data.view(byte_shape) + + # Construct tensor + return NVFP4TensorStorage( + rowwise_data=new_rowwise_data, + rowwise_scale_inv=self._rowwise_scale_inv, + columnwise_data=new_columnwise_data, + columnwise_scale_inv=self._columnwise_scale_inv, + amax_rowwise=self._amax_rowwise, + amax_columnwise=self._amax_columnwise, + quantizer=self._quantizer, + fp4_dtype=self._fp4_dtype, + ) + + def __repr__(self): + data_rowwise = self.dequantize() + + return ( + "NVFP4TensorStorage(" + f"rowwise_scaled_data={data_rowwise}," + f"rowwise_scale_inv={self._rowwise_scale_inv}," + f"amax_rowwise={self._amax_rowwise}," + f"amax_columnwise={self._amax_columnwise}," + ")" + ) + + def update_usage( + self, + rowwise_usage: Optional[bool] = None, + columnwise_usage: Optional[bool] = None, + ): + """ + For the NVFP4 format, columnwise scaled output is only produced by x2 + scaling kernels, so this function only disables usages. 
+ """ + + # Default usage is based on available data + if rowwise_usage is None: + rowwise_usage = self._rowwise_data is not None + if columnwise_usage is None: + columnwise_usage = self._columnwise_data is not None + + # Update row-scaled data + if rowwise_usage: + if self._rowwise_data is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled NVFP4 data" + ) + if self._rowwise_scale_inv is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing row-scaled scale-inverses" + ) + if self._amax_rowwise is None: + raise RuntimeError( + "Requested row-wise usage, but NVFP4Tensor is missing per tensor" + " row-scaled scale-inverse" + ) + else: + self._rowwise_data = None + self._rowwise_scale_inv = None + self._amax_rowwise = None + + # Update column-scaled data + if columnwise_usage: + if self._columnwise_data is None: + raise RuntimeError( + "Requested column-wise usage, but NVFP4Tensor is missing column-scaled FP8 data" + ) + if self._columnwise_scale_inv is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing column-scaled scale-inverses" + ) + if self._amax_columnwise is None: + raise RuntimeError( + "Requested column-wise usage, " + "but NVFP4Tensor is missing per tensor column-scaled scale-inverse" + ) + else: + self._columnwise_data = None + self._columnwise_scale_inv = None + self._amax_columnwise = None diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py index 23f56da5d..20aba6c2b 100644 --- a/transformer_engine/pytorch/tensor/utils.py +++ b/transformer_engine/pytorch/tensor/utils.py @@ -4,15 +4,18 @@ """Helper functions for using fp8 tensors as weights""" +from typing import Optional, Union, List import torch + import transformer_engine_torch as tex from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv -from .quantized_tensor import QuantizedTensor +from 
..quantized_tensor import QuantizedTensor, Quantizer, QuantizedTensorStorage from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer from ..optimizers.multi_tensor_apply import multi_tensor_applier +from ..utils import is_non_tn_fp8_gemm_supported def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): @@ -45,7 +48,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor): def cast_master_weights_to_fp8( - model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None + model_weights, + master_weights, + start_offsets, + group, + fsdp_shard_model_weights=None, + manual_post_all_gather_processing=False, ): r"""Helper function to cast master weights to FP8 primary weights. @@ -66,6 +74,11 @@ def cast_master_weights_to_fp8( fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are not sharded. Otherwise, it means that the model weights are sharded and we get target model weights data storage using the FSDP shard model weights. + manual_post_all_gather_processing: bool, default = `False`. + If False, post processing will be automatically triggered during next forward. + If True, the timing of calling post_all_gather_processing is left to the user. + Note that users must call `post_all_gather_processing` if it's set to True, + otherwise the weights won't be updated correctly. 
""" @@ -126,21 +139,18 @@ def cast_master_weights_to_fp8( f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet" ) + extra_args = [group, use_fsdp_shard_model_weights, manual_post_all_gather_processing] if len(delayed_scaling_params) > 0: - _cast_master_weights_to_fp8_delayed_scaling( - delayed_scaling_params, group, use_fsdp_shard_model_weights - ) + _cast_master_weights_to_fp8_delayed_scaling(delayed_scaling_params, *extra_args) if len(current_scaling_params) > 0: - _cast_master_weights_to_fp8_current_scaling( - current_scaling_params, group, use_fsdp_shard_model_weights - ) + _cast_master_weights_to_fp8_current_scaling(current_scaling_params, *extra_args) if len(blockwise_scaling_params) > 0: - _cast_master_weights_to_fp8_blockwise_scaling( - blockwise_scaling_params, group, use_fsdp_shard_model_weights - ) + _cast_master_weights_to_fp8_blockwise_scaling(blockwise_scaling_params, *extra_args) -def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False): +def _cast_master_weights_to_fp8_delayed_scaling( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): r"""Helper function to cast master weights to FP8 primary weights for delayed scaling. Parameters @@ -157,11 +167,12 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo amaxes, scales, scale_invs = [], [], [] for model_weight, master_weight, start_offset, shard_model_weight_raw in params: - # Reset transpose cache for all model weights. - # We cannot create transpose cache here because users (like megatron) may want to overlap - # the all-gather of model weights and forward process, so the model weight is not updated - # currently. - model_weight._reset_caches() + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. 
+ # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight._reset_caches() quantizer = model_weight._get_quantizer() @@ -222,7 +233,9 @@ def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_mo ) -def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False): +def _cast_master_weights_to_fp8_current_scaling( + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False +): r"""Helper function to cast master weights to FP8 primary weights for current scaling. Parameters @@ -300,11 +313,12 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( params, scales ): - # Reset transpose cache for all model weights. - # We cannot create transpose cache here because users (like megatron) may want to overlap - # the all-gather of model weights and forward process, so the model weight is not updated - # currently. - model_weight._reset_caches() + if not manual_post_all_gather_processing: + # Reset transpose cache for all model weights. + # We cannot create transpose cache here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated currently. + model_weight._reset_caches() # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. 
@@ -331,7 +345,7 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo def _cast_master_weights_to_fp8_blockwise_scaling( - params, group, use_fsdp_shard_model_weights=False + params, group, use_fsdp_shard_model_weights=False, manual_post_all_gather_processing=False ): r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling. @@ -430,11 +444,12 @@ def _cast_master_weights_to_fp8_blockwise_scaling( for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip( params, scales ): - # Clear columnwise data for all model weights. - # We cannot create columnwise data here because users (like megatron) may want to overlap - # the all-gather of model weights and forward process, so the model weight is not updated - # at this moment. - model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) + if not manual_post_all_gather_processing: + # Clear columnwise data for all model weights. + # We cannot create columnwise data here because users (like megatron) may want to + # overlap the all-gather of model weights and forward process, so the model weight is + # not updated at this moment. + model_weight.update_usage(rowwise_usage=True, columnwise_usage=False) # If master weight is None, it means that the master weight of the current model weight # is in other DP ranks. @@ -450,3 +465,37 @@ def _cast_master_weights_to_fp8_blockwise_scaling( tex.fp8_block_scaling_partial_cast( master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype ) + + +def post_all_gather_processing(model_weights: Union[torch.Tensor, List[torch.Tensor]]): + """ + Post-processing after all-gather for weights in distributed optimizer. + - Float8Tensor: may need to create a transposed view to match backend GEMM. + - Float8BlockwiseQTensor: create column-wise storage. + - Plain pytorch tensor: noop. 
+ """ + if not isinstance(model_weights, list): + model_weights = [model_weights] + for model_weight in model_weights: + if isinstance(model_weight, Float8Tensor): + # Delayed scaling and per-tensor current scaling: if backend does not support + # non-transposed FP8 GEMM, pre-create the transpose. + if not is_non_tn_fp8_gemm_supported(): + model_weight._create_transpose() + elif isinstance(model_weight, Float8BlockwiseQTensor): + # Blockwise scaling: create column-wise storage. + model_weight._create_columnwise() + elif isinstance(model_weight, QuantizedTensor): + raise ValueError(f"post_processing for {type(model_weight)} is not supported") + + +def is_custom(x: Optional[Union[Quantizer, QuantizedTensorStorage]] = None) -> bool: + """Check if an object is custom. + + Returns False if x is a torch.Tensor. + """ + if x is None or isinstance(x, torch.Tensor): + return False + if not isinstance(x, (Quantizer, QuantizedTensorStorage)): + raise AssertionError("Object must be a Quantizer or QuantizedTensorStorage instance") + return hasattr(x, "custom") and x.custom diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 89e43f845..4c7599ad8 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -176,7 +176,12 @@ class TransformerLayer(torch.nn.Module): activation : str, default = 'gelu' Type of activation used in MLP block. Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', - 'silu', and 'swiglu'. + 'silu', 'swiglu', and 'clamped_swiglu'. + activation_params : Optional[dict], default = `None` + Additional parameters for the activation function. + At the moment, only used for 'clamped_swiglu' activation which + supports 'limit' and 'alpha' parameters. You can set these as + `activation_params={'limit': 7.0, 'alpha': 1.702}`. 
device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the @@ -191,6 +196,17 @@ class TransformerLayer(torch.nn.Module): and `DotProductAttention` modules. name: str, default = `None` name of the module, currently used for debugging purposes. + softmax_type: str = {'vanilla', 'off-by-one', 'learnable'}, default = 'vanilla' + softmax type as described in this paper: + `Efficient Streaming Language Models with Attention Sinks + `_. + For a given attention score S = Q*K^T, of shape [b, h, s_q, s_kv], + 'vanilla': S[:,:,:,i] = exp(S[:,:,:,i])/sum(exp(S[:,:,:,:]), dim=-1), + 'off-by-one': S[:,:,:,i] = exp(S[:,:,:,i])/(1 + sum(exp(S[:,:,:,:]), dim=-1)), and + 'learnable': S[:,j,:,i] = exp(S[:,j,:,i])/(exp(alpha[j]) + sum(exp(S[:,j,:,:]), dim=-1)), + where alpha is a learnable parameter in shape [h]. + 'off-by-one' and 'learnable' softmax types are also called sink attention + ('zero sink' and 'learnable sink'). 
Parallelism parameters ---------------------- @@ -299,6 +315,7 @@ def __init__( ub_bulk_wgrad: bool = True, bias: bool = True, activation: str = "gelu", + activation_params: Optional[dict] = None, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", @@ -306,6 +323,7 @@ def __init__( qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, qk_norm_before_rope: bool = False, + softmax_type: str = "vanilla", ) -> None: super().__init__() @@ -362,6 +380,7 @@ def __init__( self.get_rng_state_tracker = get_rng_state_tracker self.attn_input_format = attn_input_format + self.softmax_type = softmax_type self.name = name @@ -397,6 +416,7 @@ def __init__( "qkv_format": self.attn_input_format, "seq_length": seq_length, "micro_batch_size": micro_batch_size, + "softmax_type": self.softmax_type, } self.self_attention = MultiheadAttention( @@ -461,6 +481,7 @@ def __init__( ub_overlap_rs=ub_overlap_rs, ub_overlap_ag=ub_overlap_ag, activation=activation, + activation_params=activation_params, normalization=normalization, device=device, name=name + ".layernorm_mlp" if name is not None else None, diff --git a/transformer_engine/pytorch/triton/__init__.py b/transformer_engine/pytorch/triton/__init__.py index 76c9b98d0..80766864d 100644 --- a/transformer_engine/pytorch/triton/__init__.py +++ b/transformer_engine/pytorch/triton/__init__.py @@ -2,4 +2,4 @@ # # See LICENSE for license information. -"""Kernels written with OpenAI Triton.""" +"""PyTorch wrappers for Triton kernels.""" diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 6eda84b91..498dd7cdd 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -4,7 +4,7 @@ # # See LICENSE for license information. 
-"""Efficient Cross Entropy kernels written with OpenAI Triton.""" +"""PyTorch wrapper functions for Cross Entropy Triton kernels.""" from typing import Union from functools import reduce @@ -14,215 +14,12 @@ import torch.distributed as dist import triton -import triton.language as tl from torch.utils.cpp_extension import IS_HIP_EXTENSION - - -@triton.jit -def online_softmax_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes the m/d components on this TP rank for the online softmax. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride (int): The stride of the m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx).to(tl.float32) - else: - X_y = float("-inf") - else: - X_y = float("-inf") - - m_d_X_y_ptr += program_id * m_d_X_y_stride * 3 - - # 3. [Online softmax] first pass: find max + sum - m = float("-inf") # m is the max value. use the notation from the paper - d = 0.0 # d is the sum. 
use the notation from the paper - - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")).to( - tl.float32 - ) - block_max = tl.max(X_block) - m_new = tl.maximum(m, block_max) - d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new)) - m = m_new - - tl.store(m_d_X_y_ptr, m) - tl.store(m_d_X_y_ptr + m_d_X_y_stride, d) - tl.store(m_d_X_y_ptr + (2 * m_d_X_y_stride), X_y) - - -@triton.jit -def cross_entropy_kernel( - X_ptr, - X_stride, - Y_ptr, - Y_stride, - loss_ptr, - loss_stride, - m_d_X_y_ptr, - m_d_X_y_stride, - rank, - world_size, - ignore_idx, - n_cols, - n_non_ignore, - reduce_loss: tl.constexpr, - label_smoothing: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - """ - This kernel computes both cross entropy loss and the gradient of the input. - - Parameters: - X_ptr: Pointer to input tensor. - X_stride (int): The stride of the input tensor. - Y_ptr: Pointer to target tensor. - Y_stride (int): The stride of the target tensor. - loss_ptr: Pointer to tensor to store the loss. - loss_stride (int): The stride of the loss tensor. - m_d_X_y_ptr: Pointer to m/d/X_y tensor. - m_d_X_y_stride: The stride of m/d/X_y tensor. - rank (int): The rank of this device in the TP group. - world_size (int): The size of world involved in this distributed loss calculation. - ignore_idx (int): Tokens to be ignored for loss and gradient calculation. - n_cols (int): The number of columns in the input tensor. - n_non_ignore (int): The number of non-ignored elements in the batch. - label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing. - BLOCK_SIZE (int): The block size for Triton operations. 
- """ - - program_id = tl.program_id(0).to(tl.int64) - - # locate the start index - X_ptr += program_id * X_stride - - # Load Y_ptr - Y_ptr += program_id * Y_stride - y = tl.load(Y_ptr) - - if y == ignore_idx: - # set all X_ptr as 0 - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols) - return - - loss_ptr += program_id * loss_stride - m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride - - # Need to reduce the m/d/X_y values from other TP ranks - m = tl.load(m_d_X_y_ptr) - d = tl.load(m_d_X_y_ptr + m_d_X_y_stride) - ori_X_y = tl.load(m_d_X_y_ptr + (2 * m_d_X_y_stride)) - - for i in range(1, world_size): - offset = i * 3 * n_non_ignore * m_d_X_y_stride - access_ptr = m_d_X_y_ptr + offset - m_new = tl.load(access_ptr) - d_new = tl.load(access_ptr + m_d_X_y_stride) - X_y_new = tl.load(access_ptr + (2 * m_d_X_y_stride)) - - d = d * tl.exp(m - tl.maximum(m, m_new)) + d_new * tl.exp(m_new - tl.maximum(m, m_new)) - m = tl.maximum(m, m_new) - ori_X_y = tl.maximum(ori_X_y, X_y_new) - - # Label smoothing is a general case of normal cross entropy - scaled_x_sum = 0.0 - eps = label_smoothing / (n_cols * world_size) - - # 4. 
[Online softmax] second pass: calculate the gradients - # dx_y = (softmax(x_y) - 1) / N - # dx_i = softmax(x_i) / N, i != y - # N is the number of non ignored elements in the batch - # For label smoothing: - # dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y - # dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N - # = dx_i - (1 - label_smoothing) / N - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")) - grad_dtype = X_block.dtype - X_block = X_block.to(tl.float32) - if label_smoothing > 0: - # scale X beforehand to avoid overflow - scaled_x_sum += tl.sum(tl.where(X_offsets < n_cols, -eps * X_block, 0.0)) - # Scale gradients based on reduction mode - # For reduce_loss=True: PyTorch will scale by 1/n_rows, so we need to scale by n_rows/n_non_ignore - # For reduce_loss=False: No additional scaling from PyTorch, so we don't scale here - if reduce_loss: - X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore) - else: - X_block = tl.exp(X_block - m) / d - eps - tl.store(X_ptr + X_offsets, X_block.to(grad_dtype), mask=X_offsets < n_cols) - - # We need tl.debug_barrier() to ensure the new result of X_ptr is written - tl.debug_barrier() - - # 5. 
Calculate the loss - - # loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X)))) - # = (X_y - max(X)) - log(sum(e ^ (X - max(X)))) - loss = -(ori_X_y - m - tl.log(d)) - - # Orginal loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps - # H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p) - # = (1 - label_smoothing) * H(q, p) + eps * sum(logsoftmax(x_i)) - # By using m (global max of xi) and d (sum of e^(xi-m)), we can simplify as: - # = (1 - label_smoothing) * H(q, p) + (-sum(x_i * eps) + label_smoothing * (m + logd)) - # Refer to H(q', p) in section 7 of the paper: https://arxiv.org/pdf/1512.00567 - if label_smoothing > 0: - smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d)) - loss = loss * (1 - label_smoothing) + smooth_loss - - # 6. Specially handle the i==y case where `dx_y = (softmax(x_y) - (1 - label_smoothing) / N` - vocab_start_idx = rank * n_cols - vocab_end_idx = (rank + 1) * n_cols - if y >= vocab_start_idx: - if y < vocab_end_idx: - X_y = tl.load(X_ptr + y - vocab_start_idx) - # Apply the same conditional scaling logic for the target token - if reduce_loss: - X_y += -(1 - label_smoothing) / (n_non_ignore) - else: - X_y += -(1 - label_smoothing) - tl.store(X_ptr + y - vocab_start_idx, X_y) - - tl.store(loss_ptr, loss) - +from transformer_engine.common.triton.cross_entropy import ( + online_softmax_kernel, + cross_entropy_kernel, + element_mul_kernel, +) # The optimal maximum block size depends on your hardware, your kernel, and your dtype MAX_FUSED_SIZE = 65536 // 2 @@ -231,44 +28,6 @@ def cross_entropy_kernel( else: NUM_WARPS = 32 -@triton.jit -def element_mul_kernel( - X_ptr, - X_stride, - grad_output_ptr, - grad_output_stride, - n_cols, - BLOCK_SIZE: tl.constexpr, -): - """ - This function multiplies each element of the tensor pointed by X_ptr with the value pointed by grad_output_ptr. 
- The multiplication is performed in-place on the tensor pointed by X_ptr. - - Parameters: - X_ptr: Pointer to the input tensor. - X_stride (int): The stride of the input tensor. - grad_output_ptr: Pointer to the gradient output value. - n_cols (int): The number of columns in the input tensor. - BLOCK_SIZE (int): The block size for Triton operations. - """ - - # Get the program ID and convert it to int64 to avoid overflow - program_id = tl.program_id(0).to(tl.int64) - - # Locate the start index - X_ptr += program_id * X_stride - - # Load the gradient output value - grad_output_ptr += program_id * grad_output_stride - grad_output = tl.load(grad_output_ptr) - - # Perform the element-wise multiplication - for i in range(0, n_cols, BLOCK_SIZE): - X_offsets = i + tl.arange(0, BLOCK_SIZE) - X_block = tl.load(X_ptr + X_offsets, mask=X_offsets < n_cols) - tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols) - - def cross_entropy_forward( _input: torch.Tensor, target: torch.Tensor, diff --git a/transformer_engine/pytorch/triton/pad.py b/transformer_engine/pytorch/triton/pad.py new file mode 100644 index 000000000..790b8277b --- /dev/null +++ b/transformer_engine/pytorch/triton/pad.py @@ -0,0 +1,43 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""PyTorch wrapper functions for padding Triton kernels.""" + +import torch +import triton + +from transformer_engine.common.triton.pad import zero_pad_kernel + + +def pad_columnwise_scale_inv(inp: torch.Tensor) -> torch.Tensor: + """Pads a tensor assuming it's a columnwise scaling inverse.""" + + assert inp.ndim == 2 + dim0, dim1 = inp.shape + + pad_x = (128 - dim0 % 128) % 128 + pad_y = (4 - dim1 % 4) % 4 + out_x = dim0 + pad_x + out_y = dim1 + pad_y + out = torch.empty((out_x, out_y), device=inp.device, dtype=inp.dtype) + + in_s0, in_s1 = inp.stride() + out_s0, out_s1 = out.stride() + + BLOCK_M, BLOCK_N = 128, 128 + grid = (triton.cdiv(out_x, BLOCK_M), triton.cdiv(out_y, BLOCK_N)) + + zero_pad_kernel[grid]( + inp, + out, + dim0, + dim1, + out_x, + out_y, + in_s0, + in_s1, + out_s0, + out_s1, + ) + return out diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index ceb88108f..da22299fe 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -2,192 +2,23 @@ # # See LICENSE for license information. 
-"""Permutation kernels written with OpenAI Triton.""" +"""PyTorch wrapper functions for Permutation Triton kernels.""" from typing import Union import torch import triton -import triton.language as tl -from triton.language import core -from triton.language.standard import _log2 - - -# The following three argsort related kernels are adapted from -# the issue https://github.com/triton-lang/triton/issues/3698 - - -@triton.jit -def _compare_and_swap(x, indices, flip, i: tl.constexpr, n_dims: tl.constexpr): - n_outer: tl.constexpr = x.numel >> n_dims - shape: tl.constexpr = [n_outer * (2**i), 2, 2 ** (n_dims - i - 1)] - y = tl.reshape(x, shape) - z = tl.reshape(indices, shape) - - mask = tl.arange(0, 2)[None, :, None] - - l_value = tl.reshape(tl.broadcast_to(tl.sum(y * (1 - mask), 1)[:, None, :], shape), x.shape).to( - x.dtype - ) - r_value = tl.reshape(tl.broadcast_to(tl.sum(y * mask, 1)[:, None, :], shape), x.shape).to( - x.dtype - ) - - l_indice = tl.reshape(tl.broadcast_to(tl.sum(z * (1 - mask), 1)[:, None, :], shape), x.shape) - r_indice = tl.reshape(tl.broadcast_to(tl.sum(z * mask, 1)[:, None, :], shape), x.shape) - - idtype = core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True) - - il_value = l_value.to(idtype, bitcast=True) - ir_value = r_value.to(idtype, bitcast=True) - ix = x.to(idtype, bitcast=True) - - flag1 = tl.where(((l_value > r_value) ^ flip) != 0, il_value ^ ir_value, tl.zeros_like(ix)) - ret = ix ^ flag1 - flag2 = tl.where(((l_value > r_value) ^ flip) != 0, l_indice ^ r_indice, tl.zeros_like(ix)) - ind = indices ^ flag2 - - return ret.to(x.dtype, bitcast=True), ind - - -@triton.jit -def _bitonic_merge(x, indices, stage: tl.constexpr, order: tl.constexpr, n_dims: tl.constexpr): - n_outer: tl.constexpr = x.numel >> n_dims - tl.static_assert(stage <= n_dims) - """ - order_type 0 == ascending - order_type 1 == descending - order_type 2 == alternating - """ - if order == 2: - shape: tl.constexpr = [n_outer * (2 ** (n_dims - 1 - stage)), 2, 
2**stage] - flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape) - else: - flip = tl.full(x.shape, value=order, dtype=tl.int32) - for i in tl.static_range(stage): - x, indices = _compare_and_swap(x, indices, flip, i + (n_dims - stage), n_dims) - return x, indices - - -@triton.jit -def _argsort(x, indices, n_dims: tl.constexpr): - for i in tl.static_range(1, n_dims + 1): - x, indices = _bitonic_merge(x, indices, i, 2 if i < n_dims else 1, n_dims) - return x, indices - - -@triton.jit -def _row_id_map_pass_1_kernel( - # pointers - routing_map_ptr, - row_id_map_ptr, - workspace_ptr, - # sizes - num_tokens, - # strides - stride_routing_map_token, - stride_routing_map_expert, - stride_row_id_map_token, - stride_row_id_map_expert, - # metas - BLOCK_SIZE: tl.constexpr, -): - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - expert_token_mask = tl.load( - routing_map_ptr + pid_m * stride_routing_map_expert + offset * stride_routing_map_token, - mask=(offset < num_tokens), - other=0, - ).to(tl.int32) - row_id_within_token_block = tl.cumsum(expert_token_mask) * expert_token_mask - tl.store( - row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, - row_id_within_token_block, - mask=offset < num_tokens, - ) - n_tokens_per_block = tl.sum(expert_token_mask) - tl.store(workspace_ptr + pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n, n_tokens_per_block) - - -@triton.jit -def _row_id_map_pass_2_kernel( - # pointers - row_id_map_ptr, - workspace_ptr, - # sizes - num_tokens, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - # metas - WORKSPACE_LOAD_WIDTH: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_m = tl.program_id(0) - pid_n = tl.program_id(1) - chunk_idx = pid_m * tl.cdiv(num_tokens, BLOCK_SIZE) + pid_n - offset = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - row_id_within_token_block = tl.load( - row_id_map_ptr + pid_m * 
stride_row_id_map_expert + offset * stride_row_id_map_token, - mask=(offset < num_tokens), - other=0, - ) - - workspace_off = tl.arange(0, WORKSPACE_LOAD_WIDTH) - n_tokens_per_chunk = tl.load(workspace_ptr + workspace_off, mask=workspace_off < chunk_idx) - row_id = tl.where( - row_id_within_token_block == 0, - -1, - row_id_within_token_block + tl.sum(n_tokens_per_chunk) - 1, - ) - tl.store( - row_id_map_ptr + pid_m * stride_row_id_map_expert + offset * stride_row_id_map_token, - row_id, - mask=(offset < num_tokens), - ) - - -@triton.jit -def _row_id_map_pass_3_kernel( - # pointers - row_id_map_ptr, - # sizes - num_experts: tl.constexpr, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - # metas - LOAD_SIZE: tl.constexpr, -): - pid = tl.program_id(0) - n_dims: tl.constexpr = _log2(LOAD_SIZE) - off = tl.arange(0, LOAD_SIZE) - row_id_map = tl.load( - row_id_map_ptr + pid * stride_row_id_map_token + stride_row_id_map_expert * off, - mask=off < num_experts, - other=-1, - ) - n_routed = tl.sum(tl.where(row_id_map != -1, 1, 0)) - indices = off - sorted_map, indices = _argsort(row_id_map, indices, n_dims=n_dims) - tl.store( - row_id_map_ptr + pid * stride_row_id_map_token + off * stride_row_id_map_expert, - sorted_map, - mask=off < n_routed, - ) - tl.store( - row_id_map_ptr - + pid * stride_row_id_map_token - + (num_experts + off) * stride_row_id_map_expert, - indices, - mask=off < n_routed, - ) - tl.store( - row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert, - n_routed, - ) +from transformer_engine.common.triton.permutation import ( + _row_id_map_pass_1_kernel, + _row_id_map_pass_2_kernel, + _row_id_map_pass_3_kernel, + _permute_kernel, + _unpermute_kernel, + _unpermute_bwd_with_merging_probs_kernel, + _make_chunk_sort_map_kernel, + _sort_chunks_by_map_kernel, +) def make_row_id_map( @@ -287,102 +118,6 @@ def make_row_id_map( return row_id_map -@triton.jit -def _permute_kernel( - # pointers - input_ptr, - 
output_ptr, - row_id_map_ptr, - probs_ptr, - scale_ptr, - permuted_probs_ptr, - permuted_scale_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, - scale_hidden_dim, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_probs_token, - stride_probs_expert, - stride_scale_token, - stride_scale_hidden, - stride_permuted_probs_token, - stride_permuted_scale_token, - stride_permuted_scale_hidden, - # metas - PERMUTE_PROBS: tl.constexpr, - PERMUTE_SCALE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - pid_t = tl.program_id(0) - pid_h = tl.program_id(1) - cur_off = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = cur_off < hidden_size - input_off = pid_t * stride_input_token + cur_off * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - if PERMUTE_SCALE: - mask_scale = cur_off < scale_hidden_dim - scale_off = pid_t * stride_scale_token + cur_off * stride_scale_hidden - scale = tl.load(scale_ptr + scale_off, mask=mask_scale) - n_routed = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + num_experts * 2 * stride_row_id_map_expert - ) - for idx in tl.range(n_routed): - dst_row = tl.load( - row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ) - output_off = dst_row * stride_output_token + cur_off * stride_output_hidden - if PERMUTE_SCALE: - permuted_scale_off = ( - dst_row * stride_permuted_scale_token + cur_off * stride_permuted_scale_hidden - ) - tl.store(permuted_scale_ptr + permuted_scale_off, scale, mask=mask_scale) - if PERMUTE_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - prob_off = pid_t * stride_probs_token + expert_idx * stride_probs_expert - prob = tl.load(probs_ptr + prob_off) - if pid_h == 0: - permuted_prob_off = dst_row * stride_permuted_probs_token - 
tl.store(permuted_probs_ptr + permuted_prob_off, prob) - if prob == 0.0: - # for routing_map padding - # dst_row != -1 and prob == 0.0 means that this slot is padded - tl.store(output_ptr + output_off, 0.0, mask=mask) - else: - tl.store(output_ptr + output_off, inp, mask=mask) - else: - tl.store(output_ptr + output_off, inp, mask=mask) - - -try: - _permute_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_permute_kernel) -except RuntimeError: - pass - - def permute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -462,115 +197,6 @@ def permute_with_mask_map( return output, permuted_scale, permuted_probs -@triton.jit -def _unpermute_kernel( - # pointers - input_ptr, - output_ptr, - row_id_map_ptr, - merging_probs_ptr, - permuted_probs_ptr, - unpermuted_probs_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_merging_probs_token, - stride_merging_probs_expert, - stride_permuted_probs_token, - stride_unpermuted_probs_token, - stride_unpermuted_probs_expert, - # metas - PROBS_LOAD_WIDTH: tl.constexpr, - WITH_MERGING_PROBS: tl.constexpr, - PERMUTE_PROBS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - data_type = input_ptr.dtype.element_ty - compute_type = tl.float32 - - pid_t = tl.program_id(0) - pid_h = tl.program_id(1) - current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - if PERMUTE_PROBS: - # write 0.0 to probs_grad that are not routed - if pid_h == 0: - map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) - unpermuted_prob_off = ( - pid_t 
* stride_unpermuted_probs_token - + stride_unpermuted_probs_expert * map_load_off - ) - tl.store( - unpermuted_probs_ptr + unpermuted_prob_off, 0.0, mask=map_load_off < num_experts - ) - accumulator = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - n_routed = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + num_experts * 2 * stride_row_id_map_expert - ) - for idx in tl.range(n_routed): - src_row = tl.load( - row_id_map_ptr + pid_t * stride_row_id_map_token + idx * stride_row_id_map_expert - ) - input_off = src_row * stride_input_token + current_offset * stride_input_hidden - inp = tl.load(input_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - if WITH_MERGING_PROBS: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - merging_prob_off = ( - pid_t * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - inp *= merging_prob - accumulator += inp - if PERMUTE_PROBS: - if pid_h == 0: - expert_idx = tl.load( - row_id_map_ptr - + pid_t * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - unpermuted_prob_off = ( - pid_t * stride_unpermuted_probs_token - + expert_idx * stride_unpermuted_probs_expert - ) - permuted_prob_off = src_row * stride_permuted_probs_token - prob = tl.load(permuted_probs_ptr + permuted_prob_off) - tl.store(unpermuted_probs_ptr + unpermuted_prob_off, prob) - accumulator = accumulator.to(data_type) - output_off = pid_t * stride_output_token + current_offset * stride_output_hidden - tl.store(output_ptr + output_off, accumulator, mask=mask) - - -try: - _unpermute_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 
2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_unpermute_kernel) -except RuntimeError: - pass - - def unpermute_with_mask_map( inp: torch.Tensor, row_id_map: torch.Tensor, @@ -637,108 +263,6 @@ def unpermute_with_mask_map( return output, unpermuted_probs -@triton.jit -def _unpermute_bwd_with_merging_probs_kernel( - # pointers - fwd_output_grad_ptr, - fwd_input_grad_ptr, - fwd_input_ptr, - merging_probs_ptr, - merging_probs_grad_ptr, - row_id_map_ptr, - # sizes - num_experts: tl.constexpr, - hidden_size: tl.constexpr, - # strides - stride_row_id_map_token, - stride_row_id_map_expert, - stride_fwd_output_grad_token, - stride_fwd_output_grad_hidden, - stride_fwd_input_grad_token, - stride_fwd_input_grad_hidden, - stride_fwd_input_token, - stride_fwd_input_hidden, - stride_merging_probs_token, - stride_merging_probs_expert, - stride_merging_probs_grad_token, - stride_merging_probs_grad_expert, - # metas - PROBS_LOAD_WIDTH: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - data_type = fwd_output_grad_ptr.dtype.element_ty - compute_type = tl.float32 - - pid = tl.program_id(0) - map_load_off = tl.arange(0, PROBS_LOAD_WIDTH) - token_probs_grad_off = ( - pid * stride_merging_probs_grad_token + stride_merging_probs_grad_expert * map_load_off - ) - tl.store(merging_probs_grad_ptr + token_probs_grad_off, 0.0, mask=map_load_off < num_experts) - n_routed = tl.load( - row_id_map_ptr + pid * stride_row_id_map_token + num_experts * 2 * stride_row_id_map_expert - ) - for idx in tl.range(n_routed): - dst_row = tl.load( - row_id_map_ptr + pid * stride_row_id_map_token + idx * stride_row_id_map_expert - ) - expert_idx = tl.load( - row_id_map_ptr - + pid * stride_row_id_map_token - + (num_experts + idx) * stride_row_id_map_expert - ) - prob_grad_accum = tl.zeros((BLOCK_SIZE,), dtype=compute_type) - current_start = 0 - while current_start < hidden_size: - current_offset = current_start + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - 
input_off = ( - pid * stride_fwd_output_grad_token + current_offset * stride_fwd_output_grad_hidden - ) - inp = tl.load(fwd_output_grad_ptr + input_off, mask=mask) - inp = inp.to(compute_type) - merging_prob_off = ( - pid * stride_merging_probs_token + expert_idx * stride_merging_probs_expert - ) - merging_prob = tl.load(merging_probs_ptr + merging_prob_off).to(compute_type) - output = inp * merging_prob - output = output.to(data_type) - output_off = ( - dst_row * stride_fwd_input_grad_token - + current_offset * stride_fwd_input_grad_hidden - ) - tl.store(fwd_input_grad_ptr + output_off, output, mask=mask) - - fwd_input_off = ( - dst_row * stride_fwd_input_token + current_offset * stride_fwd_input_hidden - ) - fwd_input = tl.load(fwd_input_ptr + fwd_input_off, mask=mask) - prob_grad_accum += fwd_input.to(compute_type) * inp - current_start += BLOCK_SIZE - probs_grad = tl.sum(prob_grad_accum).to(merging_probs_grad_ptr.dtype.element_ty) - probs_grad_off = ( - pid * stride_merging_probs_grad_token + expert_idx * stride_merging_probs_grad_expert - ) - tl.store(merging_probs_grad_ptr + probs_grad_off, probs_grad) - - -try: - _unpermute_bwd_with_merging_probs_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_unpermute_bwd_with_merging_probs_kernel) -except RuntimeError: - pass - - def unpermute_with_mask_map_bwd_with_merging_probs( fwd_output_grad: torch.Tensor, row_id_map: torch.Tensor, @@ -804,47 +328,6 @@ def unpermute_with_mask_map_bwd_with_merging_probs( return act_grad, merging_probs_grad -@triton.jit -def _make_chunk_sort_map_kernel( - # pointers - split_sizes_ptr, - sorted_indices_ptr, - dst_rows_ptr, - # sizes - num_splits: tl.constexpr, - # metas - IDX_LOAD_WIDTH: tl.constexpr, 
-): - pid = tl.program_id(0) - - load_split_offset = tl.arange(0, IDX_LOAD_WIDTH) - sorted_indices = tl.load( - sorted_indices_ptr + load_split_offset, mask=load_split_offset < num_splits - ) - - # get chunk idx of the current token in the input tensor - input_split_sizes = tl.load( - split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 - ).to(tl.int32) - input_split_sizes_cumsum = tl.cumsum(input_split_sizes) - input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) - input_chunk_idx = tl.sum(input_split_sizes_mask) - input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) - in_chunk_offset = pid - input_split_sizes_presum - - # get chunk idx of the current token in the output tensor - output_chunk_mask = tl.where(sorted_indices == input_chunk_idx, 1, 0) - output_chunk_idx = tl.argmax(output_chunk_mask, axis=-1) - - # make row_id_map - output_split_sizes = tl.load( - split_sizes_ptr + sorted_indices, mask=load_split_offset < num_splits - ).to(tl.int32) - output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) - dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset - tl.store(dst_rows_ptr + pid, dst_row) - - def make_chunk_sort_map( split_sizes: torch.Tensor, sorted_indices: torch.Tensor, @@ -877,67 +360,6 @@ def make_chunk_sort_map( return row_id_map -@triton.jit -def _sort_chunks_by_map_kernel( - # pointers - input_ptr, - output_ptr, - row_id_map_ptr, - probs_ptr, - permuted_probs_ptr, - # sizes - hidden_size: tl.constexpr, - # strides - stride_input_token, - stride_input_hidden, - stride_output_token, - stride_output_hidden, - stride_probs_token, - stride_permuted_probs_token, - # metas - PERMUTE_PROBS: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - FORWARD: tl.constexpr, -): - pid_t = tl.program_id(0) - pid_h = tl.program_id(1) - if FORWARD: - src_row = pid_t - dst_row = tl.load(row_id_map_ptr + pid_t) - else: - src_row = tl.load(row_id_map_ptr + pid_t) 
- dst_row = pid_t - current_offset = pid_h * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - mask = current_offset < hidden_size - input_offsets = src_row * stride_input_token + current_offset * stride_input_hidden - output_offsets = dst_row * stride_output_token + current_offset * stride_output_hidden - inp = tl.load(input_ptr + input_offsets, mask=mask) - tl.store(output_ptr + output_offsets, inp, mask=mask) - if PERMUTE_PROBS: - if pid_h == 0: - prob_off = src_row * stride_probs_token - prob = tl.load(probs_ptr + prob_off) - permuted_prob_off = dst_row * stride_permuted_probs_token - tl.store(permuted_probs_ptr + permuted_prob_off, prob) - - -try: - _sort_chunks_by_map_kernel = triton.autotune( - configs=[ - triton.Config({"BLOCK_SIZE": 64}), - triton.Config({"BLOCK_SIZE": 128}), - triton.Config({"BLOCK_SIZE": 256}), - triton.Config({"BLOCK_SIZE": 512}), - triton.Config({"BLOCK_SIZE": 1024}), - triton.Config({"BLOCK_SIZE": 2048}), - triton.Config({"BLOCK_SIZE": 4096}), - ], - key=["hidden_size"], - )(_sort_chunks_by_map_kernel) -except RuntimeError: - pass - - def sort_chunks_by_map( inp: torch.Tensor, row_id_map: torch.Tensor, diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index b6a7270a3..eae6bb79c 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -10,11 +10,11 @@ from ..utils import is_non_tn_fp8_gemm_supported -from ..tensor._internal.float8_tensor_base import Float8TensorBase +from ..tensor.storage.float8_tensor_storage import Float8TensorStorage from .cast_transpose import te_cast_transpose_mxfp8_triton, te_cast_transpose_noop_triton, te_dequantize_mxfp8_triton import transformer_engine_torch as tex -from ..tensor.quantized_tensor import QuantizedTensor, Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase +from ..quantized_tensor import QuantizedTensor, Quantizer +from 
..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage @functools.lru_cache(maxsize=None) def _empty_tensor() -> torch.Tensor: @@ -72,7 +72,7 @@ def te_quantize_triton( _setup_conditional_transpose_storage(out) else: out = quantizer.make_empty(input_tensor.shape, dtype=fake_tensor_type) - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): _setup_conditional_transpose_storage(out) else: # Create a QuantizedTensor from the provided output tensor @@ -82,11 +82,11 @@ def te_quantize_triton( if noop_flag is None: noop_flag = _empty_tensor() # if it's mxfp8, we'll check if both rowwise and columnwise data are none - if (isinstance(out, MXFP8TensorBase) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorBase) and out.size().numel() == 0): + if (isinstance(out, MXFP8TensorStorage) and out._rowwise_data is None and out._columnwise_data is None) or (not isinstance(out, MXFP8TensorStorage) and out.size().numel() == 0): # Return empty output if the quantized tensor has no elements return out - if isinstance(out, Float8TensorBase): + if isinstance(out, Float8TensorStorage): if input_tensor.nelement() > 0: if not out._transpose_invalid: quantizer = out._get_quantizer() @@ -117,7 +117,7 @@ def te_quantize_triton( else: out.remove_caches() #Make sure to remove transpose if it is marked as invalid out = tex.quantize(input_tensor, quantizer, out, noop_flag) - elif isinstance(out, MXFP8TensorBase): + elif isinstance(out, MXFP8TensorStorage): te_cast_transpose_mxfp8_triton(input_tensor, out) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(out).__name__}'") @@ -125,9 +125,9 @@ def te_quantize_triton( return out def te_dequantize_triton(input, dtype: tex.DType): - if isinstance(input, MXFP8TensorBase): + if isinstance(input, MXFP8TensorStorage): return te_dequantize_mxfp8_triton(input, dtype) - elif isinstance(input, Float8TensorBase): + elif isinstance(input, 
Float8TensorStorage): return tex.dequantize(input, dtype) else: raise NotImplementedError(f"Not implemented for tensor type: '{type(input).__name__}'") diff --git a/transformer_engine/pytorch/triton_kernels/layernorm.py b/transformer_engine/pytorch/triton_kernels/layernorm.py index 3baa64697..86b7b46c7 100644 --- a/transformer_engine/pytorch/triton_kernels/layernorm.py +++ b/transformer_engine/pytorch/triton_kernels/layernorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information @@ -10,7 +10,7 @@ from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ..constants import TE_DType from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer from ..triton_kernels.cast import te_quantize_triton import triton import triton.language as tl diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index 9f152582e..1ca6183c9 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -1,4 +1,4 @@ -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. 
See LICENSE for more information import torch @@ -13,7 +13,7 @@ te_dtype_to_triton_dtype, ) from .common import get_fp8_max -from ..tensor.quantized_tensor import Quantizer +from ..quantized_tensor import Quantizer import transformer_engine_torch as tex def dg_tmp_rows(x, sm_margin=None): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index d124fbeaf..12c62437d 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -14,11 +14,16 @@ import torch from torch.utils.cpp_extension import IS_HIP_EXTENSION -import transformer_engine.pytorch.cpp_extensions as ext from . import torch_version +from .quantized_tensor import Quantizer from ..debug.pytorch.debug_quantization import DebugQuantizedTensor +__all__ = ["get_device_compute_capability", "get_cudnn_version", "is_bf16_available"] + +if IS_HIP_EXTENSION: + __all__.extend(["is_mi200", "is_mi308", "is_fp8_fnuz"]) + def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: """Check if any of the given tensors require gradient.""" for tensor in tensors: @@ -187,7 +192,7 @@ def combine_tensors( num_tensors = len(tensors) new_shape = list(tensors[0].shape) new_shape.insert(dim, num_tensors) - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor if isinstance(tensors[0], Float8Tensor): new_stride = list(tensors[0]._data.stride()) @@ -227,14 +232,16 @@ def forward( # pylint: disable=missing-function-docstring ctx.split_dim = split_dim ctx.split_size_or_sections = split_size_or_sections - from 
transformer_engine.pytorch.float8_tensor import Float8Tensor - from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.storage.float8_tensor_storage import ( + Float8TensorStorage, + ) - if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance( + if isinstance(mixed_x_layer, Float8TensorStorage) and not isinstance( mixed_x_layer, Float8Tensor ): return tuple( - Float8TensorBase( + Float8TensorStorage( fp8_scale_inv=mixed_x_layer._scale_inv, fp8_dtype=mixed_x_layer._fp8_dtype, data=x.squeeze(split_dim) if squeeze else x, @@ -279,7 +286,7 @@ def backward(ctx, *grad_outputs): split_sizes = [ctx.split_size_or_sections] * len(grad_outputs) dims = len(grad_outputs[0].shape) split_dim = (ctx.split_dim + dims) % dims - from transformer_engine.pytorch.float8_tensor import Float8Tensor + from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor if isinstance(grad_outputs[0], Float8Tensor): noop_ok = True @@ -463,6 +470,15 @@ def is_fp8_fnuz(): get_torch_float8_e4m3_type = lambda: torch.float8_e4m3fnuz if is_fp8_fnuz() else torch.float8_e4m3fn get_torch_float8_e5m2_type = lambda: torch.float8_e5m2fnuz if is_fp8_fnuz() else torch.float8_e5m2 +def assert_dim_for_all_gather( + tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer +) -> None: + """Assert that tensor dimensions are supported for all-gather""" + if with_all_gather: + assert quantizer.is_quantizable(tensor), ( + "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__ + ) + def is_bf16_compatible() -> None: if IS_HIP_EXTENSION: # only MI200 and newer machines support bf16 @@ -476,6 +492,28 @@ def is_bf16_compatible() -> None: """ return torch.cuda.get_device_capability()[0] >= 8 +def is_bf16_available(return_reason: bool = False) -> Union[bool, Tuple[bool, str]]: + """ + Determine whether bfloat16 
(BF16) computation is supported on the current device. + + Parameters + ---------- + return_reason : bool, optional + If ``False`` (default), return only a boolean indicating BF16 availability. + If ``True``, return a tuple ``(is_available, reason)`` where ``reason`` provides + a human-readable explanation when BF16 is not available. When BF16 is available, + the reason will be an empty string. + + """ + available = is_bf16_compatible() + if not return_reason: + return available + + reason = ( + "" if available else "BF16 support requires a GPU with compute capability 8.0 or higher." + ) + return available, reason + @functools.lru_cache(maxsize=None) def is_non_tn_fp8_gemm_supported() -> bool: @@ -492,6 +530,8 @@ def get_cudnn_version() -> Tuple[int, int, int]: # ROCm fused attn does not use cudnn, return high numbers to avoid tests filtering out if IS_HIP_EXTENSION: return (99, 0, 0) + import transformer_engine.pytorch.cpp_extensions as ext + encoded_version = ext.get_cudnn_version() major_version_magnitude = 1000 if encoded_version < 90000 else 10000 major, encoded_version = divmod(encoded_version, major_version_magnitude)